diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 4f5f9ab4..32ea7990 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -8,7 +8,6 @@ import torch
 from onnx import ModelProto, load, numpy_helper
 from onnx.checker import check_model
 from onnx.external_data_helper import (
-    ExternalDataInfo,
     convert_model_to_external_data,
     set_external_data,
     write_external_data_tensors,
@@ -61,7 +60,6 @@ def fix_node_name(key: str):
 def merge_lora(
     base_name: str,
     lora_names: str,
-    dest_path: str,
     dest_type: Literal["text_encoder", "unet"],
     lora_weights: "np.NDArray[np.float64]" = None,
 ):
@@ -227,7 +225,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
-    blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
+    blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
     if args.dest is None or args.dest == "" or args.dest == "ort":
         # convert to external data and save to memory
         (bare_model, external_data) = buffer_external_data_tensors(blend_model)
@@ -247,3 +245,7 @@ if __name__ == "__main__":
             model_file.write(bare_model.SerializeToString())

         logger.info("successfully saved blended model: %s", dest_file)
+
+        check_model(dest_file)
+
+        logger.info("checked blended model")
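With dest_path removed, merge_lora takes the base model path, the list of LoRA files, and the destination component, plus optional per-model weights. A minimal call sketch under that signature, with illustrative paths that are not from this diff:

    # hedged sketch of the reduced merge_lora signature; paths are illustrative
    blended = merge_lora(
        "models/stable-diffusion-onnx-v1-5/unet/model.onnx",
        ["models/lora/arch.safetensors", "models/lora/glass.safetensors"],
        "unet",
    )

Note that a stale caller still passing None in the old dest_path position would silently shift the component name into lora_weights, so call sites must drop that argument along with the parameter.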
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 6c95364f..06136a1c 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -1,6 +1,6 @@
 from logging import getLogger
 from os import path
-from typing import Any, Optional, Tuple
+from typing import Any, List, Optional, Tuple

 import numpy as np
 from diffusers import (
@@ -106,6 +106,13 @@ def get_tile_latents(
     return full_latents[:, :, y:yt, x:xt]


+def get_loras_from_prompt(prompt: str) -> List[str]:
+    return [
+        "arch",
+        "glass",
+    ]
+
+
 def optimize_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -156,7 +163,9 @@ def load_pipeline(
     device: DeviceParams,
     lpw: bool,
     inversion: Optional[str],
+    loras: Optional[List[str]] = None,
 ):
+    loras = loras or []
     pipe_key = (
         pipeline.__name__,
         model,
@@ -164,6 +173,7 @@ def load_pipeline(
         device.provider,
         lpw,
         inversion,
+        loras,
     )
     scheduler_key = (scheduler_name, model)
     scheduler_type = get_pipeline_schedulers()[scheduler_name]
@@ -223,26 +233,36 @@ def load_pipeline(
         )

         # test LoRA blending
-        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
-            "arch",
-            "glass",
-        ]]
+        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras]
+        logger.info("blending base model %s with LoRA models: %s", model, lora_models)

-        logger.info("blending text encoder with LoRA models: %s", lora_models)
-        blended_text_encoder = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), lora_models, None, "text_encoder")
+        # blend and load text encoder
+        blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder")
         (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
         text_encoder_names, text_encoder_values = zip(*text_encoder_data)
         text_encoder_opts = SessionOptions()
         text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
-        components["text_encoder"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(text_encoder_model.SerializeToString(), provider=device.ort_provider(), sess_options=text_encoder_opts))
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder_model.SerializeToString(),
+                provider=device.ort_provider(),
+                sess_options=text_encoder_opts,
+            )
+        )

-        logger.info("blending unet with LoRA models: %s", lora_models)
-        blended_unet = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/unet/model.onnx"), lora_models, None, "unet")
+        # blend and load unet
+        blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
         (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
         unet_names, unet_values = zip(*unet_data)
         unet_opts = SessionOptions()
         unet_opts.add_external_initializers(list(unet_names), list(unet_values))
-        components["unet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(unet_model.SerializeToString(), provider=device.ort_provider(), sess_options=unet_opts))
+        components["unet"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                unet_model.SerializeToString(),
+                provider=device.ort_provider(),
+                sess_options=unet_opts,
+            )
+        )

     pipe = pipeline.from_pretrained(
         model,
diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py
index edb286fc..95c3380c 100644
--- a/api/onnx_web/diffusers/run.py
+++ b/api/onnx_web/diffusers/run.py
@@ -14,7 +14,7 @@ from ..server import ServerContext
 from ..upscale import run_upscale_correction
 from ..utils import run_gc
 from ..worker import WorkerContext
-from .load import get_latents_from_seed, load_pipeline
+from .load import get_latents_from_seed, get_loras_from_prompt, load_pipeline

 logger = getLogger(__name__)

@@ -28,6 +28,7 @@ def run_txt2img_pipeline(
     upscale: UpscaleParams,
 ) -> None:
     latents = get_latents_from_seed(params.seed, size, batch=params.batch)
+    loras = get_loras_from_prompt(params.prompt)
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionPipeline,
@@ -36,6 +37,7 @@ def run_txt2img_pipeline(
         job.get_device(),
         params.lpw,
         params.inversion,
+        loras,
     )

     progress = job.get_progress_callback()
@@ -99,6 +101,7 @@ def run_img2img_pipeline(
     source: Image.Image,
     strength: float,
 ) -> None:
+    loras = get_loras_from_prompt(params.prompt)
     pipe = load_pipeline(
         server,
         OnnxStableDiffusionImg2ImgPipeline,
@@ -107,6 +110,7 @@ def run_img2img_pipeline(
         job.get_device(),
         params.lpw,
         params.inversion,
+        loras,
     )
     progress = job.get_progress_callback()
     if params.lpw:
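In this diff, get_loras_from_prompt is a hardcoded stub that always selects the arch and glass models. A minimal sketch of the intended behavior, assuming a hypothetical <lora:name> prompt token; the token format is not part of this diff:

    import re
    from typing import List

    # hypothetical token format, e.g. "an arch of glass <lora:arch> <lora:glass>"
    LORA_TOKEN = re.compile(r"<lora:([-\w]+)>")

    def get_loras_from_prompt(prompt: str) -> List[str]:
        # each <lora:name> token selects models/lora/<name>.safetensors,
        # matching the path.join(server.model_path, "lora", ...) lookup above
        return LORA_TOKEN.findall(prompt)

The tokens would likely also need to be stripped from the prompt before encoding, since they carry no meaning for the text encoder.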