1
0
Fork 0

lint(api): start renaming inversions to embeddings in code

This commit is contained in:
Sean Sube 2023-09-24 18:15:58 -05:00
parent cdb09d2b44
commit e338fcd0e0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 38 additions and 23 deletions

View File

@ -52,7 +52,7 @@ class BlendImg2ImgStage(BaseStage):
params, params,
pipe_type, pipe_type,
worker.get_device(), worker.get_device(),
inversions=inversions, embeddings=inversions,
loras=loras, loras=loras,
) )

View File

@ -79,7 +79,7 @@ class SourceTxt2ImgStage(BaseStage):
params, params,
pipe_type, pipe_type,
worker.get_device(), worker.get_device(),
inversions=inversions, embeddings=inversions,
loras=loras, loras=loras,
) )

View File

@ -56,7 +56,7 @@ class UpscaleOutpaintStage(BaseStage):
params, params,
pipe_type, pipe_type,
worker.get_device(), worker.get_device(),
inversions=inversions, embeddings=inversions,
loras=loras, loras=loras,
) )

View File

@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import path from os import path
from typing import Any, List, Optional, Tuple from typing import Any, List, Literal, Optional, Tuple
from onnx import load_model from onnx import load_model
from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline, from optimum.onnxruntime import ( # ORTStableDiffusionXLInpaintPipeline,
@ -114,11 +114,11 @@ def load_pipeline(
params: ImageParams, params: ImageParams,
pipeline: str, pipeline: str,
device: DeviceParams, device: DeviceParams,
inversions: Optional[List[Tuple[str, float]]] = None, embeddings: Optional[List[Tuple[str, float]]] = None,
loras: Optional[List[Tuple[str, float]]] = None, loras: Optional[List[Tuple[str, float]]] = None,
model: Optional[str] = None, model: Optional[str] = None,
): ):
inversions = inversions or [] embeddings = embeddings or []
loras = loras or [] loras = loras or []
model = model or params.model model = model or params.model
@ -132,7 +132,7 @@ def load_pipeline(
device.device, device.device,
device.provider, device.provider,
control_key, control_key,
inversions, embeddings,
loras, loras,
) )
scheduler_key = (params.scheduler, model) scheduler_key = (params.scheduler, model)
@ -189,9 +189,9 @@ def load_pipeline(
components.update(control_components) components.update(control_components)
unet_type = "cnet" unet_type = "cnet"
# Textual Inversion blending # load various pipeline components
encoder_components = load_text_encoders( encoder_components = load_text_encoders(
server, device, model, inversions, loras, torch_dtype, params server, device, model, embeddings, loras, torch_dtype, params
) )
components.update(encoder_components) components.update(encoder_components)
@ -277,7 +277,7 @@ def load_pipeline(
return pipe return pipe
def load_controlnet(server, device, params): def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx") cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
logger.debug("loading ControlNet weights from %s", cnet_path) logger.debug("loading ControlNet weights from %s", cnet_path)
components = {} components = {}
@ -292,7 +292,13 @@ def load_controlnet(server, device, params):
def load_text_encoders( def load_text_encoders(
server, device, model: str, inversions, loras, torch_dtype, params server: ServerContext,
device: DeviceParams,
model: str,
embeddings: Optional[List[Tuple[str, float]]],
loras: Optional[List[Tuple[str, float]]],
torch_dtype,
params: ImageParams,
): ):
tokenizer = CLIPTokenizer.from_pretrained( tokenizer = CLIPTokenizer.from_pretrained(
model, model,
@ -310,13 +316,13 @@ def load_text_encoders(
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL)) text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
# blend embeddings, if any # blend embeddings, if any
if inversions is not None and len(inversions) > 0: if embeddings is not None and len(embeddings) > 0:
inversion_names, inversion_weights = zip(*inversions) embedding_names, embedding_weights = zip(*embeddings)
inversion_models = [ embedding_models = [
path.join(server.model_path, "inversion", name) for name in inversion_names path.join(server.model_path, "inversion", name) for name in embedding_names
] ]
logger.debug( logger.debug(
"blending base model %s with embeddings from %s", model, inversion_models "blending base model %s with embeddings from %s", model, embedding_models
) )
# TODO: blend text_encoder_2 as well # TODO: blend text_encoder_2 as well
@ -326,10 +332,10 @@ def load_text_encoders(
tokenizer, tokenizer,
list( list(
zip( zip(
inversion_models, embedding_models,
inversion_weights, embedding_weights,
inversion_names, embedding_names,
[None] * len(inversion_models), [None] * len(embedding_models),
) )
), ),
) )
@ -340,7 +346,7 @@ def load_text_encoders(
lora_models = [ lora_models = [
path.join(server.model_path, "lora", name) for name in lora_names path.join(server.model_path, "lora", name) for name in lora_names
] ]
logger.info("blending base model %s with LoRA models: %s", model, lora_models) logger.info("blending base model %s with LoRAs from %s", model, lora_models)
# blend and load text encoder # blend and load text encoder
text_encoder = blend_loras( text_encoder = blend_loras(
@ -411,7 +417,14 @@ def load_text_encoders(
return components return components
def load_unet(server, device, model, loras, unet_type, params): def load_unet(
server: ServerContext,
device: DeviceParams,
model: str,
loras: List[Tuple[str, float]],
unet_type: Literal["cnet", "unet"],
params: ImageParams,
):
components = {} components = {}
unet = load_model(path.join(model, unet_type, ONNX_MODEL)) unet = load_model(path.join(model, unet_type, ONNX_MODEL))
@ -457,7 +470,9 @@ def load_unet(server, device, model, loras, unet_type, params):
return components return components
def load_vae(server, device, model, params): def load_vae(
server: ServerContext, device: DeviceParams, model: str, params: ImageParams
):
# one or more VAE models need to be loaded # one or more VAE models need to be loaded
vae = path.join(model, "vae", ONNX_MODEL) vae = path.join(model, "vae", ONNX_MODEL)
vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL) vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)