
lint(api): start renaming inversions to embeddings in code

Sean Sube 2023-09-24 18:15:58 -05:00
parent cdb09d2b44
commit e338fcd0e0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 38 additions and 23 deletions

View File

@@ -52,7 +52,7 @@ class BlendImg2ImgStage(BaseStage):
             params,
             pipe_type,
             worker.get_device(),
-            inversions=inversions,
+            embeddings=inversions,
             loras=loras,
         )

View File

@@ -79,7 +79,7 @@ class SourceTxt2ImgStage(BaseStage):
             params,
             pipe_type,
             worker.get_device(),
-            inversions=inversions,
+            embeddings=inversions,
             loras=loras,
         )

View File

@@ -56,7 +56,7 @@ class UpscaleOutpaintStage(BaseStage):
             params,
             pipe_type,
             worker.get_device(),
-            inversions=inversions,
+            embeddings=inversions,
             loras=loras,
         )
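The three hunks above make the same one-keyword change: each stage still builds a local list named inversions, but now passes it as embeddings=. A hedged sketch of one such call site after this commit; the callee name load_pipeline and the leading server argument are inferred from the load.py hunks below and are not visible in these hunks:

    # sketch of a stage calling the pipeline loader after the rename;
    # server, params, pipe_type, inversions, and loras are locals of the stage method
    pipe = load_pipeline(
        server,  # assumed leading argument, not shown in the hunks above
        params,
        pipe_type,
        worker.get_device(),
        embeddings=inversions,  # new keyword, old local name: the rename is only "started" here
        loras=loras,
    )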

View File

@@ -1,6 +1,6 @@
 from logging import getLogger
 from os import path
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Literal, Optional, Tuple
 
 from onnx import load_model
 from optimum.onnxruntime import (  # ORTStableDiffusionXLInpaintPipeline,
@@ -114,11 +114,11 @@ def load_pipeline(
     params: ImageParams,
     pipeline: str,
     device: DeviceParams,
-    inversions: Optional[List[Tuple[str, float]]] = None,
+    embeddings: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
     model: Optional[str] = None,
 ):
-    inversions = inversions or []
+    embeddings = embeddings or []
     loras = loras or []
     model = model or params.model
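The renamed parameter keeps the old shape: an optional list of (name, weight) tuples, normalized to an empty list inside the function. A minimal hedged example of building such a list; the embedding names are hypothetical:

    from typing import List, Optional, Tuple

    # hypothetical Textual Inversion embeddings, each scaled by a blend weight
    embeddings: Optional[List[Tuple[str, float]]] = [
        ("bad-hands", 1.0),
        ("line-art", 0.6),
    ]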
@@ -132,7 +132,7 @@ def load_pipeline(
         device.device,
         device.provider,
         control_key,
-        inversions,
+        embeddings,
         loras,
     )
     scheduler_key = (params.scheduler, model)
@@ -189,9 +189,9 @@ def load_pipeline(
         components.update(control_components)
         unet_type = "cnet"
 
-    # Textual Inversion blending
+    # load various pipeline components
     encoder_components = load_text_encoders(
-        server, device, model, inversions, loras, torch_dtype, params
+        server, device, model, embeddings, loras, torch_dtype, params
     )
     components.update(encoder_components)
@@ -277,7 +277,7 @@ def load_pipeline(
     return pipe
 
 
-def load_controlnet(server, device, params):
+def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
     cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
     logger.debug("loading ControlNet weights from %s", cnet_path)
     components = {}
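The typed load_controlnet signature does not change the lookup itself: the ControlNet weights are still resolved under the server model path. A hedged example with a hypothetical network name and model directory:

    from os import path

    # hypothetical: a ControlNet named "canny" selected via params.control, with a
    # model path of "/models", would resolve to "/models/control/canny.onnx"
    cnet_path = path.join("/models", "control", "canny.onnx")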
@@ -292,7 +292,13 @@ def load_controlnet(server, device, params):
 
 
 def load_text_encoders(
-    server, device, model: str, inversions, loras, torch_dtype, params
+    server: ServerContext,
+    device: DeviceParams,
+    model: str,
+    embeddings: Optional[List[Tuple[str, float]]],
+    loras: Optional[List[Tuple[str, float]]],
+    torch_dtype,
+    params: ImageParams,
 ):
     tokenizer = CLIPTokenizer.from_pretrained(
         model,
@@ -310,13 +316,13 @@ def load_text_encoders(
     text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
 
     # blend embeddings, if any
-    if inversions is not None and len(inversions) > 0:
-        inversion_names, inversion_weights = zip(*inversions)
-        inversion_models = [
-            path.join(server.model_path, "inversion", name) for name in inversion_names
+    if embeddings is not None and len(embeddings) > 0:
+        embedding_names, embedding_weights = zip(*embeddings)
+        embedding_models = [
+            path.join(server.model_path, "inversion", name) for name in embedding_names
         ]
         logger.debug(
-            "blending base model %s with embeddings from %s", model, inversion_models
+            "blending base model %s with embeddings from %s", model, embedding_models
         )
 
         # TODO: blend text_encoder_2 as well
@@ -326,10 +332,10 @@ def load_text_encoders(
             tokenizer,
             list(
                 zip(
-                    inversion_models,
-                    inversion_weights,
-                    inversion_names,
-                    [None] * len(inversion_models),
+                    embedding_models,
+                    embedding_weights,
+                    embedding_names,
+                    [None] * len(embedding_models),
                 )
             ),
         )
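The unpack-and-rezip pattern above is easier to follow outside the diff. A small hedged worked example with hypothetical embedding names, weights, and model path, showing the 4-tuples that get zipped together (the fourth slot is None for every entry here, as in the hunk):

    from os import path

    model_path = "/models"  # hypothetical stand-in for server.model_path
    embeddings = [("style-a", 0.75), ("style-b", 1.0)]  # hypothetical (name, weight) pairs
    embedding_names, embedding_weights = zip(*embeddings)  # ("style-a", "style-b"), (0.75, 1.0)
    embedding_models = [
        path.join(model_path, "inversion", name) for name in embedding_names
    ]
    blend_args = list(
        zip(
            embedding_models,
            embedding_weights,
            embedding_names,
            [None] * len(embedding_models),
        )
    )
    # blend_args == [
    #     ("/models/inversion/style-a", 0.75, "style-a", None),
    #     ("/models/inversion/style-b", 1.0, "style-b", None),
    # ]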
@@ -340,7 +346,7 @@ def load_text_encoders(
         lora_models = [
             path.join(server.model_path, "lora", name) for name in lora_names
         ]
-        logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+        logger.info("blending base model %s with LoRAs from %s", model, lora_models)
 
         # blend and load text encoder
         text_encoder = blend_loras(
@@ -411,7 +417,14 @@ def load_text_encoders(
     return components
 
 
-def load_unet(server, device, model, loras, unet_type, params):
+def load_unet(
+    server: ServerContext,
+    device: DeviceParams,
+    model: str,
+    loras: List[Tuple[str, float]],
+    unet_type: Literal["cnet", "unet"],
+    params: ImageParams,
+):
     components = {}
     unet = load_model(path.join(model, unet_type, ONNX_MODEL))
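The Literal added to the typing import in the first hunk of this file exists for the unet_type annotation above: only the two subfolder names that load_unet joins onto the model path are valid. A hedged sketch of how a caller might narrow the value; this helper is hypothetical and only illustrates the annotation, assuming the choice follows the unet_type = "cnet" assignment in the control branch shown earlier:

    from typing import Literal, Optional

    # hypothetical helper, for illustration only: narrows unet_type to the two
    # ONNX subfolder names accepted by load_unet
    def pick_unet_type(control: Optional[object]) -> Literal["cnet", "unet"]:
        return "cnet" if control is not None else "unet"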
@@ -457,7 +470,9 @@ def load_unet(server, device, model, loras, unet_type, params):
     return components
 
 
-def load_vae(server, device, model, params):
+def load_vae(
+    server: ServerContext, device: DeviceParams, model: str, params: ImageParams
+):
     # one or more VAE models need to be loaded
     vae = path.join(model, "vae", ONNX_MODEL)
     vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)