diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 7a643377..7add40e6 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -1,7 +1,7 @@
 from argparse import ArgumentParser
 from logging import getLogger
-from typing import Dict, List, Literal, Tuple
 from os import path
+from typing import Dict, List, Literal, Tuple
 
 import numpy as np
 import torch
@@ -12,7 +12,7 @@ from onnx.external_data_helper import (
     set_external_data,
     write_external_data_tensors,
 )
-from onnxruntime import OrtValue, InferenceSession, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions
 from safetensors.torch import load_file
 
 from onnx_web.convert.utils import ConversionContext
@@ -25,7 +25,9 @@ logger = getLogger(__name__)
 ###
 
 
-def buffer_external_data_tensors(model: ModelProto) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
+def buffer_external_data_tensors(
+    model: ModelProto,
+) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
     external_data = []
     for tensor in model.graph.initializer:
         name = tensor.name
@@ -74,17 +76,19 @@ def merge_lora(
     lora_prefix = f"lora_{dest_type}_"
 
     blended: Dict[str, np.ndarray] = {}
-    for lora_name, lora_model, lora_weight in zip(lora_names, lora_models, lora_weights):
+    for lora_name, lora_model, lora_weight in zip(
+        lora_names, lora_models, lora_weights
+    ):
         logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         for key in lora_model.keys():
             if ".lora_down" in key and lora_prefix in key:
-                base_key = key[: key.index(".lora_down")].replace(
-                    lora_prefix, ""
-                )
+                base_key = key[: key.index(".lora_down")].replace(lora_prefix, "")
                 up_key = key.replace("lora_down", "lora_up")
                 alpha_key = key[: key.index("lora_down")] + "alpha"
-                logger.info("blending weights for keys: %s, %s, %s", key, up_key, alpha_key)
+                logger.info(
+                    "blending weights for keys: %s, %s, %s", key, up_key, alpha_key
+                )
 
                 down_weight = lora_model[key].to(dtype=torch.float32)
                 up_weight = lora_model[up_key].to(dtype=torch.float32)
 
@@ -95,12 +99,22 @@ def merge_lora(
                 try:
                     if len(up_weight.size()) == 2:
                         # blend for nn.Linear
-                        logger.info("blending weights for Linear node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        logger.info(
+                            "blending weights for Linear node: %s, %s, %s",
+                            down_weight.shape,
+                            up_weight.shape,
+                            alpha,
+                        )
                         weights = up_weight @ down_weight
-                        np_weights = (weights.numpy() * (alpha / dim))
+                        np_weights = weights.numpy() * (alpha / dim)
                     elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
                         # blend for nn.Conv2d 1x1
-                        logger.info("blending weights for Conv node: %s, %s, %s", down_weight.shape, up_weight.shape, alpha)
+                        logger.info(
+                            "blending weights for Conv node: %s, %s, %s",
+                            down_weight.shape,
+                            up_weight.shape,
+                            alpha,
+                        )
                         weights = (
                             (
                                 up_weight.squeeze(3).squeeze(2)
@@ -109,10 +123,14 @@ def merge_lora(
                             .unsqueeze(2)
                             .unsqueeze(3)
                         )
-                        np_weights = (weights.numpy() * (alpha / dim))
+                        np_weights = weights.numpy() * (alpha / dim)
                     else:
                         # TODO: add support for Conv2d 3x3
-                        logger.warning("unknown LoRA node type at %s: %s", base_key, up_weight.shape[-2:])
+                        logger.warning(
+                            "unknown LoRA node type at %s: %s",
+                            base_key,
+                            up_weight.shape[-2:],
+                        )
                         continue
 
                     np_weights *= lora_weight
@@ -122,15 +140,13 @@ def merge_lora(
                         blended[base_key] = np_weights
 
                 except Exception:
-                    logger.exception(
-                        "error blending weights for key %s", base_key
-                    )
+                    logger.exception("error blending weights for key %s", base_key)
 
     logger.info(
         "updating %s of %s initializers: %s",
         len(blended.keys()),
         len(base_model.graph.initializer),
-        list(blended.keys())
+        list(blended.keys()),
     )
 
     fixed_initializer_names = [
@@ -138,17 +154,19 @@ def merge_lora(
     ]
     # logger.info("fixed initializer names: %s", fixed_initializer_names)
 
-    fixed_node_names = [
-        fix_node_name(node.name) for node in base_model.graph.node
-    ]
+    fixed_node_names = [fix_node_name(node.name) for node in base_model.graph.node]
     # logger.info("fixed node names: %s", fixed_node_names)
-
     for base_key, weights in blended.items():
         conv_key = base_key + "_Conv"
         matmul_key = base_key + "_MatMul"
-        logger.info("key %s has conv: %s, matmul: %s", base_key, conv_key in fixed_node_names, matmul_key in fixed_node_names)
+        logger.info(
+            "key %s has conv: %s, matmul: %s",
+            base_key,
+            conv_key in fixed_node_names,
+            matmul_key in fixed_node_names,
+        )
 
         if conv_key in fixed_node_names:
             conv_idx = fixed_node_names.index(conv_key)
@@ -166,7 +184,11 @@ def merge_lora(
 
             # blending
             base_weights = numpy_helper.to_array(weight_node)
-            logger.info("found blended weights for conv: %s, %s", weights.shape, base_weights.shape)
+            logger.info(
+                "found blended weights for conv: %s, %s",
+                weights.shape,
+                base_weights.shape,
+            )
 
             blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
             blended = np.expand_dims(blended, (2, 3))
@@ -191,7 +213,11 @@ def merge_lora(
 
             # blending
             base_weights = numpy_helper.to_array(matmul_node)
-            logger.info("found blended weights for matmul: %s, %s", weights.shape, base_weights.shape)
+            logger.info(
+                "found blended weights for matmul: %s, %s",
+                weights.shape,
+                base_weights.shape,
+            )
 
             blended = base_weights + weights.transpose()
             logger.info("blended weight shape: %s", blended.shape)
@@ -208,7 +234,7 @@ def merge_lora(
         len(fixed_initializer_names),
         len(base_model.graph.initializer),
         len(fixed_node_names),
-        len(base_model.graph.node)
+        len(base_model.graph.node),
     )
 
     return base_model
@@ -219,11 +245,16 @@ if __name__ == "__main__":
     parser.add_argument("--base", type=str)
     parser.add_argument("--dest", type=str)
     parser.add_argument("--type", type=str, choices=["text_encoder", "unet"])
-    parser.add_argument("--lora_models", nargs='+', type=str)
-    parser.add_argument("--lora_weights", nargs='+', type=float)
+    parser.add_argument("--lora_models", nargs="+", type=str)
+    parser.add_argument("--lora_weights", nargs="+", type=float)
 
     args = parser.parse_args()
-    logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
+    logger.info(
+        "merging %s with %s with weights: %s",
+        args.lora_models,
+        args.base,
+        args.lora_weights,
+    )
 
     blend_model = merge_lora(args.base, args.lora_models, args.type, args.lora_weights)
     if args.dest is None or args.dest == "" or args.dest == "ort":
@@ -234,10 +265,18 @@ if __name__ == "__main__":
         external_names, external_values = zip(*external_data)
         opts = SessionOptions()
         opts.add_external_initializers(list(external_names), list(external_values))
-        sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
-        logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
+        sess = InferenceSession(
+            bare_model.SerializeToString(),
+            sess_options=opts,
+            providers=["CPUExecutionProvider"],
+        )
+        logger.info(
+            "successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]
+        )
     else:
-        convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
+        convert_model_to_external_data(
+            blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb"
+        )
         bare_model = write_external_data_tensors(blend_model, args.dest)
 
         dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 7ee12a06..5fe3feba 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -37,7 +37,7 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
 
-from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
+from ..convert.diffusion.lora import buffer_external_data_tensors, merge_lora
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -118,7 +118,10 @@ def get_loras_from_prompt(prompt: str) -> Tuple[str, List[str]]:
         name, weight = next_match.groups()
         loras.append(name)
         # remove this match and look for another
-        remaining_prompt = remaining_prompt[:next_match.start()] + remaining_prompt[next_match.end():]
+        remaining_prompt = (
+            remaining_prompt[: next_match.start()]
+            + remaining_prompt[next_match.end() :]
+        )
         next_match = lora_expr.search(remaining_prompt)
 
     return (remaining_prompt, loras)
@@ -244,15 +247,23 @@ def load_pipeline(
             )
 
             # test LoRA blending
-            lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras]
+            lora_models = [
+                path.join(server.model_path, "lora", f"{i}.safetensors") for i in loras
+            ]
             logger.info("blending base model %s with LoRA models: %s", model, lora_models)
 
             # blend and load text encoder
-            blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder")
-            (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+            blended_text_encoder = merge_lora(
+                path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder"
+            )
+            (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(
+                blended_text_encoder
+            )
             text_encoder_names, text_encoder_values = zip(*text_encoder_data)
             text_encoder_opts = SessionOptions()
-            text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+            text_encoder_opts.add_external_initializers(
+                list(text_encoder_names), list(text_encoder_values)
+            )
             components["text_encoder"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
                     text_encoder_model.SerializeToString(),
@@ -262,7 +273,9 @@ def load_pipeline(
             )
 
             # blend and load unet
-            blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
+            blended_unet = merge_lora(
+                path.join(model, "unet", "model.onnx"), lora_models, "unet"
+            )
             (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
             unet_names, unet_values = zip(*unet_data)
             unet_opts = SessionOptions()
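
Note: the update that `merge_lora` applies in the `nn.Linear` branch above is the standard LoRA blend, `W' = W + (up @ down) * (alpha / dim) * weight`. A minimal numpy sketch of that math, assuming kohya-style `alpha` and rank conventions; the helper name and shape comments below are illustrative, not part of this diff:

```python
# Sketch of the Linear blend performed by merge_lora; illustrative only.
import numpy as np


def blend_linear(
    base: np.ndarray,  # base model weight, shape (out, in)
    up: np.ndarray,  # lora_up weight, shape (out, rank)
    down: np.ndarray,  # lora_down weight, shape (rank, in)
    alpha: float,  # per-module alpha stored alongside the LoRA weights
    lora_weight: float,  # user-supplied blending weight
) -> np.ndarray:
    rank = down.shape[0]  # "dim" in the diff above
    # W' = W + (up @ down) * (alpha / rank) * weight
    return base + (up @ down) * (alpha / rank) * lora_weight
```

For the ONNX `MatMul` case the diff adds `weights.transpose()` rather than `weights`, since the ONNX initializer is stored transposed relative to the torch `nn.Linear` layout; the Conv 1x1 branch performs the same blend after squeezing out the two singleton spatial axes.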