From a7f77a033dce5c763093de46f7b28acbcc499a17 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Wed, 15 Mar 2023 08:30:31 -0500
Subject: [PATCH] feat(api): parse LoRA weights from prompt

---
 api/onnx_web/convert/diffusion/lora.py |  2 +-
 api/onnx_web/diffusers/load.py         | 86 +++++++++-----------------
 api/onnx_web/diffusers/run.py          |  3 +-
 api/onnx_web/diffusers/utils.py        | 19 +++++-
 4 files changed, 50 insertions(+), 60 deletions(-)

diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 7add40e6..47cda3fb 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -61,7 +61,7 @@ def fix_node_name(key: str):
 
 def merge_lora(
     base_name: str,
-    lora_names: str,
+    lora_names: List[str],
     dest_type: Literal["text_encoder", "unet"],
     lora_weights: "np.NDArray[np.float64]" = None,
 ):
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 5fe3feba..26316c5f 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -37,7 +37,7 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
 
-from ..convert.diffusion.lora import buffer_external_data_tensors, merge_lora
+from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -107,26 +107,6 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]
 
 
-def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
-    remaining_prompt = prompt
-    lora_expr = compile(r"\<lora:(\w+):([\.|\d]+)\>")
-
-    loras = []
-    next_match = lora_expr.search(remaining_prompt)
-    while next_match is not None:
-        logger.debug("found LoRA token in prompt: %s", next_match)
-        name, weight = next_match.groups()
-        loras.append(name)
-        # remove this match and look for another
-        remaining_prompt = (
-            remaining_prompt[: next_match.start()]
-            + remaining_prompt[next_match.end() :]
-        )
-        next_match = lora_expr.search(remaining_prompt)
-
-    return (remaining_prompt, loras)
-
-
 def optimize_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -177,7 +157,7 @@ def load_pipeline(
     device: DeviceParams,
     lpw: bool,
     inversion: Optional[str],
-    loras: Optional[List[str]] = None,
+    loras: Optional[List[Tuple[str, float]]] = None,
 ):
     loras = loras or []
     pipe_key = (
@@ -247,46 +227,38 @@ def load_pipeline(
     )
 
     # test LoRA blending
-    lora_models = [
-        path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras
-    ]
+    lora_names, lora_weights = zip(*loras) if len(loras) > 0 else ([], [])
+    lora_models = [path.join(server.model_path, "lora", f"{name}.safetensors") for name in lora_names]
     logger.info("blending base model %s with LoRA models: %s", model, lora_models)
 
-    # blend and load text encoder
-    blended_text_encoder = merge_lora(
-        path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder"
-    )
-    (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(
-        blended_text_encoder
-    )
-    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
-    text_encoder_opts = SessionOptions()
-    text_encoder_opts.add_external_initializers(
-        list(text_encoder_names), list(text_encoder_values)
-    )
-    components["text_encoder"] = OnnxRuntimeModel(
-        OnnxRuntimeModel.load_model(
-            text_encoder_model.SerializeToString(),
-            provider=device.ort_provider(),
-            sess_options=text_encoder_opts,
+    if len(lora_models) > 0:
+        # blend and load text encoder
"model.onnx"), lora_models, "text_encoder", lora_weights=lora_weights) + (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder) + text_encoder_names, text_encoder_values = zip(*text_encoder_data) + text_encoder_opts = SessionOptions() + text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values)) + components["text_encoder"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + text_encoder_model.SerializeToString(), + provider=device.ort_provider(), + sess_options=text_encoder_opts, + ) ) - ) - # blend and load unet - blended_unet = merge_lora( - path.join(model, "unet", "model.onnx"), lora_models, "unet" - ) - (unet_model, unet_data) = buffer_external_data_tensors(blended_unet) - unet_names, unet_values = zip(*unet_data) - unet_opts = SessionOptions() - unet_opts.add_external_initializers(list(unet_names), list(unet_values)) - components["unet"] = OnnxRuntimeModel( - OnnxRuntimeModel.load_model( - unet_model.SerializeToString(), - provider=device.ort_provider(), - sess_options=unet_opts, + # blend and load unet + blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet", lora_weights=lora_weights) + (unet_model, unet_data) = buffer_external_data_tensors(blended_unet) + unet_names, unet_values = zip(*unet_data) + unet_opts = SessionOptions() + unet_opts.add_external_initializers(list(unet_names), list(unet_values)) + components["unet"] = OnnxRuntimeModel( + OnnxRuntimeModel.load_model( + unet_model.SerializeToString(), + provider=device.ort_provider(), + sess_options=unet_opts, + ) ) - ) pipe = pipeline.from_pretrained( model, diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 3d5cb935..dd7320ee 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -14,7 +14,8 @@ from ..server import ServerContext from ..upscale import run_upscale_correction from ..utils import run_gc from ..worker import WorkerContext -from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline +from .load import get_latents_from_seed, load_pipeline +from .utils import get_loras_from_prompt logger = getLogger(__name__) diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 66e97f69..c122225a 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -1,7 +1,7 @@ from logging import getLogger from math import ceil from re import compile -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np from diffusers import OnnxStableDiffusionPipeline @@ -128,3 +128,20 @@ def expand_prompt( logger.debug("expanded prompt shape: %s", prompt_embeds.shape) return prompt_embeds + + +def get_loras_from_prompt(prompt: str) -> Tuple[str, List[Tuple[str, float]]]: + remaining_prompt = prompt + lora_expr = compile(r"\") + + loras = [] + next_match = lora_expr.search(remaining_prompt) + while next_match is not None: + logger.debug("found LoRA token in prompt: %s", next_match) + name, weight = next_match.groups() + loras.append((name, float(weight))) + # remove this match and look for another + remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():] + next_match = lora_expr.search(remaining_prompt) + + return (remaining_prompt, loras)