lint(api): start renaming inversions to embeddings in code
This commit is contained in:
parent
cdb09d2b44
commit
e338fcd0e0
|
@ -52,7 +52,7 @@ class BlendImg2ImgStage(BaseStage):
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
worker.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
embeddings=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ class SourceTxt2ImgStage(BaseStage):
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
worker.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
embeddings=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,7 @@ class UpscaleOutpaintStage(BaseStage):
|
||||||
params,
|
params,
|
||||||
pipe_type,
|
pipe_type,
|
||||||
worker.get_device(),
|
worker.get_device(),
|
||||||
inversions=inversions,
|
embeddings=inversions,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Literal, Optional, Tuple
|
||||||
|
|
||||||
from onnx import load_model
|
from onnx import load_model
|
||||||
from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
|
from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
|
||||||
|
@ -114,11 +114,11 @@ def load_pipeline(
|
||||||
params: ImageParams,
|
params: ImageParams,
|
||||||
pipeline: str,
|
pipeline: str,
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
embeddings: Optional[List[Tuple[str, float]]] = None,
|
||||||
loras: Optional[List[Tuple[str, float]]] = None,
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
):
|
):
|
||||||
inversions = inversions or []
|
embeddings = embeddings or []
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
model = model or params.model
|
model = model or params.model
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ def load_pipeline(
|
||||||
device.device,
|
device.device,
|
||||||
device.provider,
|
device.provider,
|
||||||
control_key,
|
control_key,
|
||||||
inversions,
|
embeddings,
|
||||||
loras,
|
loras,
|
||||||
)
|
)
|
||||||
scheduler_key = (params.scheduler, model)
|
scheduler_key = (params.scheduler, model)
|
||||||
|
@ -189,9 +189,9 @@ def load_pipeline(
|
||||||
components.update(control_components)
|
components.update(control_components)
|
||||||
unet_type = "cnet"
|
unet_type = "cnet"
|
||||||
|
|
||||||
# Textual Inversion blending
|
# load various pipeline components
|
||||||
encoder_components = load_text_encoders(
|
encoder_components = load_text_encoders(
|
||||||
server, device, model, inversions, loras, torch_dtype, params
|
server, device, model, embeddings, loras, torch_dtype, params
|
||||||
)
|
)
|
||||||
components.update(encoder_components)
|
components.update(encoder_components)
|
||||||
|
|
||||||
|
@ -277,7 +277,7 @@ def load_pipeline(
|
||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet(server, device, params):
|
def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
|
||||||
cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
|
cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
|
||||||
logger.debug("loading ControlNet weights from %s", cnet_path)
|
logger.debug("loading ControlNet weights from %s", cnet_path)
|
||||||
components = {}
|
components = {}
|
||||||
|
@ -292,7 +292,13 @@ def load_controlnet(server, device, params):
|
||||||
|
|
||||||
|
|
||||||
def load_text_encoders(
|
def load_text_encoders(
|
||||||
server, device, model: str, inversions, loras, torch_dtype, params
|
server: ServerContext,
|
||||||
|
device: DeviceParams,
|
||||||
|
model: str,
|
||||||
|
embeddings: Optional[List[Tuple[str, float]]],
|
||||||
|
loras: Optional[List[Tuple[str, float]]],
|
||||||
|
torch_dtype,
|
||||||
|
params: ImageParams,
|
||||||
):
|
):
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
model,
|
model,
|
||||||
|
@ -310,13 +316,13 @@ def load_text_encoders(
|
||||||
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
|
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
|
||||||
|
|
||||||
# blend embeddings, if any
|
# blend embeddings, if any
|
||||||
if inversions is not None and len(inversions) > 0:
|
if embeddings is not None and len(embeddings) > 0:
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
embedding_names, embedding_weights = zip(*embeddings)
|
||||||
inversion_models = [
|
embedding_models = [
|
||||||
path.join(server.model_path, "inversion", name) for name in inversion_names
|
path.join(server.model_path, "inversion", name) for name in embedding_names
|
||||||
]
|
]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"blending base model %s with embeddings from %s", model, inversion_models
|
"blending base model %s with embeddings from %s", model, embedding_models
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: blend text_encoder_2 as well
|
# TODO: blend text_encoder_2 as well
|
||||||
|
@ -326,10 +332,10 @@ def load_text_encoders(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
list(
|
list(
|
||||||
zip(
|
zip(
|
||||||
inversion_models,
|
embedding_models,
|
||||||
inversion_weights,
|
embedding_weights,
|
||||||
inversion_names,
|
embedding_names,
|
||||||
[None] * len(inversion_models),
|
[None] * len(embedding_models),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -340,7 +346,7 @@ def load_text_encoders(
|
||||||
lora_models = [
|
lora_models = [
|
||||||
path.join(server.model_path, "lora", name) for name in lora_names
|
path.join(server.model_path, "lora", name) for name in lora_names
|
||||||
]
|
]
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
logger.info("blending base model %s with LoRAs from %s", model, lora_models)
|
||||||
|
|
||||||
# blend and load text encoder
|
# blend and load text encoder
|
||||||
text_encoder = blend_loras(
|
text_encoder = blend_loras(
|
||||||
|
@ -411,7 +417,14 @@ def load_text_encoders(
|
||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
def load_unet(server, device, model, loras, unet_type, params):
|
def load_unet(
|
||||||
|
server: ServerContext,
|
||||||
|
device: DeviceParams,
|
||||||
|
model: str,
|
||||||
|
loras: List[Tuple[str, float]],
|
||||||
|
unet_type: Literal["cnet", "unet"],
|
||||||
|
params: ImageParams,
|
||||||
|
):
|
||||||
components = {}
|
components = {}
|
||||||
unet = load_model(path.join(model, unet_type, ONNX_MODEL))
|
unet = load_model(path.join(model, unet_type, ONNX_MODEL))
|
||||||
|
|
||||||
|
@ -457,7 +470,9 @@ def load_unet(server, device, model, loras, unet_type, params):
|
||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
def load_vae(server, device, model, params):
|
def load_vae(
|
||||||
|
server: ServerContext, device: DeviceParams, model: str, params: ImageParams
|
||||||
|
):
|
||||||
# one or more VAE models need to be loaded
|
# one or more VAE models need to be loaded
|
||||||
vae = path.join(model, "vae", ONNX_MODEL)
|
vae = path.join(model, "vae", ONNX_MODEL)
|
||||||
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
|
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
|
||||||
|
|
Loading…
Reference in New Issue