From e338fcd0e0ac0725f4baef8ac81df6705849000a Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 24 Sep 2023 18:15:58 -0500
Subject: [PATCH] lint(api): start renaming inversions to embeddings in code

---
 api/onnx_web/chain/blend_img2img.py    |  2 +-
 api/onnx_web/chain/source_txt2img.py   |  2 +-
 api/onnx_web/chain/upscale_outpaint.py |  2 +-
 api/onnx_web/diffusers/load.py         | 55 ++++++++++++++++++++----------
 4 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py
index 0d5ad28e..af181c10 100644
--- a/api/onnx_web/chain/blend_img2img.py
+++ b/api/onnx_web/chain/blend_img2img.py
@@ -52,7 +52,7 @@ class BlendImg2ImgStage(BaseStage):
             params,
             pipe_type,
             worker.get_device(),
-            inversions=inversions,
+            embeddings=inversions,
             loras=loras,
         )
 
diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py
index b4dd7c24..5448364e 100644
--- a/api/onnx_web/chain/source_txt2img.py
+++ b/api/onnx_web/chain/source_txt2img.py
@@ -79,7 +79,7 @@ class SourceTxt2ImgStage(BaseStage):
             params,
             pipe_type,
             worker.get_device(),
-            inversions=inversions,
+            embeddings=inversions,
             loras=loras,
         )
 
diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py
index 67f7ca0a..78d32077 100644
--- a/api/onnx_web/chain/upscale_outpaint.py
+++ b/api/onnx_web/chain/upscale_outpaint.py
@@ -56,7 +56,7 @@ class UpscaleOutpaintStage(BaseStage):
             params,
             pipe_type,
             worker.get_device(),
-            inversions=inversions,
+            embeddings=inversions,
             loras=loras,
         )
 
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 37800e71..1198d85d 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -1,6 +1,6 @@
 from logging import getLogger
 from os import path
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Literal, Optional, Tuple
 
 from onnx import load_model
 from optimum.onnxruntime import (  # ORTStableDiffusionXLInpaintPipeline,
@@ -114,11 +114,11 @@ def load_pipeline(
     server: ServerContext,
     params: ImageParams,
     pipeline: str,
     device: DeviceParams,
-    inversions: Optional[List[Tuple[str, float]]] = None,
+    embeddings: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
     model: Optional[str] = None,
 ):
-    inversions = inversions or []
+    embeddings = embeddings or []
     loras = loras or []
     model = model or params.model
@@ -132,7 +132,7 @@ def load_pipeline(
         device.device,
         device.provider,
         control_key,
-        inversions,
+        embeddings,
         loras,
     )
     scheduler_key = (params.scheduler, model)
@@ -189,9 +189,9 @@ def load_pipeline(
         components.update(control_components)
         unet_type = "cnet"
 
-    # Textual Inversion blending
+    # load the text encoders, blending any embeddings and LoRAs
     encoder_components = load_text_encoders(
-        server, device, model, inversions, loras, torch_dtype, params
+        server, device, model, embeddings, loras, torch_dtype, params
     )
     components.update(encoder_components)
 
@@ -277,7 +277,7 @@ def load_pipeline(
     return pipe
 
 
-def load_controlnet(server, device, params):
+def load_controlnet(server: ServerContext, device: DeviceParams, params: ImageParams):
     cnet_path = path.join(server.model_path, "control", f"{params.control.name}.onnx")
     logger.debug("loading ControlNet weights from %s", cnet_path)
     components = {}
@@ -292,7 +292,13 @@ def load_controlnet(server, device, params):
 
 
 def load_text_encoders(
-    server, device, model: str, inversions, loras, torch_dtype, params
+    server: ServerContext,
+    device: DeviceParams,
+    model: str,
+    embeddings: Optional[List[Tuple[str, float]]],
+    loras: Optional[List[Tuple[str, float]]],
+    torch_dtype,
+    params: ImageParams,
 ):
     tokenizer = CLIPTokenizer.from_pretrained(
         model,
@@ -310,13 +316,13 @@ def load_text_encoders(
         text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
 
     # blend embeddings, if any
-    if inversions is not None and len(inversions) > 0:
-        inversion_names, inversion_weights = zip(*inversions)
-        inversion_models = [
-            path.join(server.model_path, "inversion", name) for name in inversion_names
+    if embeddings is not None and len(embeddings) > 0:
+        embedding_names, embedding_weights = zip(*embeddings)
+        embedding_models = [
+            path.join(server.model_path, "inversion", name) for name in embedding_names
         ]
         logger.debug(
-            "blending base model %s with embeddings from %s", model, inversion_models
+            "blending base model %s with embeddings from %s", model, embedding_models
         )
 
         # TODO: blend text_encoder_2 as well
@@ -326,10 +332,10 @@ def load_text_encoders(
             tokenizer,
             list(
                 zip(
-                    inversion_models,
-                    inversion_weights,
-                    inversion_names,
-                    [None] * len(inversion_models),
+                    embedding_models,
+                    embedding_weights,
+                    embedding_names,
+                    [None] * len(embedding_models),
                 )
             ),
         )
@@ -340,7 +346,7 @@ def load_text_encoders(
         lora_models = [
            path.join(server.model_path, "lora", name) for name in lora_names
         ]
-        logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+        logger.info("blending base model %s with LoRAs from %s", model, lora_models)
 
         # blend and load text encoder
         text_encoder = blend_loras(
@@ -411,7 +417,14 @@ def load_text_encoders(
     return components
 
 
-def load_unet(server, device, model, loras, unet_type, params):
+def load_unet(
+    server: ServerContext,
+    device: DeviceParams,
+    model: str,
+    loras: List[Tuple[str, float]],
+    unet_type: Literal["cnet", "unet"],
+    params: ImageParams,
+):
     components = {}
     unet = load_model(path.join(model, unet_type, ONNX_MODEL))
 
@@ -457,7 +470,9 @@ def load_unet(server, device, model, loras, unet_type, params):
     return components
 
 
-def load_vae(server, device, model, params):
+def load_vae(
+    server: ServerContext, device: DeviceParams, model: str, params: ImageParams
+):
     # one or more VAE models need to be loaded
     vae = path.join(model, "vae", ONNX_MODEL)
     vae_decoder = path.join(model, "vae_decoder", ONNX_MODEL)
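
Note: a minimal sketch of a call site after this change. Only the
`load_pipeline(..., embeddings=...)` keyword and the `(name, weight)` tuple
shape come from this patch; the surrounding stage variables are illustrative
assumptions, and the local variable is still named `inversions` because the
rename has only started on the parameter side.

    # hypothetical stage code, for illustration only
    inversions = [("my-style", 1.0)]  # (name, weight); name resolves under <model_path>/inversion
    loras = [("my-lora", 0.8)]        # (name, weight); name resolves under <model_path>/lora

    pipe = load_pipeline(
        server,                 # ServerContext
        params,                 # ImageParams
        pipe_type,
        worker.get_device(),
        embeddings=inversions,  # was: inversions=inversions
        loras=loras,
    )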