diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py
index 7a1d1b54..4f5f9ab4 100644
--- a/api/onnx_web/convert/diffusion/lora.py
+++ b/api/onnx_web/convert/diffusion/lora.py
@@ -205,27 +205,15 @@ def merge_lora(
     else:
         logger.info("could not find any nodes for %s", base_key)

-    logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
+    logger.info(
+        "node counts: %s -> %s, %s -> %s",
+        len(fixed_initializer_names),
+        len(base_model.graph.initializer),
+        len(fixed_node_names),
+        len(base_model.graph.node)
+    )

-    if dest_path is None or dest_path == "" or dest_path == "ort":
-        # convert to external data and save to memory
-        (bare_model, external_data) = buffer_external_data_tensors(base_model)
-        logger.info("saved external data for %s nodes", len(external_data))
-
-        external_names, external_values = zip(*external_data)
-        opts = SessionOptions()
-        opts.add_external_initializers(list(external_names), list(external_values))
-        sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
-        logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()])
-    else:
-        convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb")
-        bare_model = write_external_data_tensors(base_model, dest_path)
-        dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
-
-        with open(dest_file, "w+b") as model_file:
-            model_file.write(bare_model.SerializeToString())
-
-        logger.info("successfully saved model: %s", dest_file)
+    return base_model


 if __name__ == "__main__":
@@ -238,4 +226,24 @@ if __name__ == "__main__":
     args = parser.parse_args()
     logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)

-    merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
+
+    blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
+    if args.dest is None or args.dest == "" or args.dest == "ort":
+        # convert to external data and save to memory
+        (bare_model, external_data) = buffer_external_data_tensors(blend_model)
+        logger.info("saved external data for %s nodes", len(external_data))
+
+        external_names, external_values = zip(*external_data)
+        opts = SessionOptions()
+        opts.add_external_initializers(list(external_names), list(external_values))
+        sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
+        logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
+    else:
+        convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
+        bare_model = write_external_data_tensors(blend_model, args.dest)
+        dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
+
+        with open(dest_file, "w+b") as model_file:
+            model_file.write(bare_model.SerializeToString())
+
+        logger.info("successfully saved blended model: %s", dest_file)
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index b616fba4..340be21d 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -21,6 +21,7 @@ from diffusers import (
     PNDMScheduler,
     StableDiffusionPipeline,
 )
+from onnxruntime import SessionOptions
 from transformers import CLIPTokenizer

 from onnx_web.diffusers.utils import expand_prompt
@@ -35,6 +36,7 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler

+from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -220,6 +222,28 @@ def load_pipeline(
                 path.join(inversion, "tokenizer"),
             )

+        # test LoRA blending
+        lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
+            "arch",
+            "glass",
+        ]]
+
+        logger.info("blending text encoder with LoRA models: %s", lora_models)
+        blended_text_encoder = merge_lora("text_encoder", lora_models, None, "text_encoder")
+        (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
+        text_encoder_opts = SessionOptions()
+        text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+        components["text_encoder"] = OnnxRuntimeModel.from_pretrained(text_encoder_model, sess_options=text_encoder_opts)
+
+        logger.info("blending unet with LoRA models: %s", lora_models)
+        blended_unet = merge_lora("unet", lora_models, None, "unet")
+        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
+        unet_names, unet_values = zip(*unet_data)
+        unet_opts = SessionOptions()
+        unet_opts.add_external_initializers(list(unet_names), list(unet_values))
+        components["unet"] = OnnxRuntimeModel.from_pretrained(unet_model, sess_options=unet_opts)
+
         pipe = pipeline.from_pretrained(
             model,
             custom_pipeline=custom_pipeline,
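
Note: both halves of this change lean on the same in-memory external-initializer pattern: strip each initializer's weights out of the ModelProto, keep them alive as onnxruntime OrtValues, and load the slimmed model without writing anything to disk. Below is a minimal, self-contained sketch of that pattern, roughly what buffer_external_data_tensors appears to do here; the tiny Add graph and the "in-memory" location tag are illustrative only, not code from this repo.

    import numpy as np
    from onnx import TensorProto, helper, numpy_helper
    from onnx.external_data_helper import set_external_data
    from onnxruntime import InferenceSession, OrtValue, SessionOptions

    # Tiny illustrative model: y = x + w, with w stored as an initializer.
    w = numpy_helper.from_array(np.ones((2, 2), dtype=np.float32), name="w")
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 2])
    y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 2])
    model = helper.make_model(
        helper.make_graph([helper.make_node("Add", ["x", "w"], ["y"])], "demo", [x], [y], [w]),
        opset_imports=[helper.make_operatorsetid("", 13)],
    )

    # Buffer each initializer: keep the weights alive as an OrtValue and mark
    # the tensor as external, so the serialized model no longer embeds the bytes.
    external_data = []
    for tensor in model.graph.initializer:
        if tensor.raw_data:
            buffer = OrtValue.ortvalue_from_numpy(numpy_helper.to_array(tensor))
            external_data.append((tensor.name, buffer))
            set_external_data(tensor, location="in-memory")
            tensor.ClearField("raw_data")

    # Register the buffered weights by name, then load the bare model from memory.
    names, values = zip(*external_data)
    opts = SessionOptions()
    opts.add_external_initializers(list(names), list(values))
    sess = InferenceSession(model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
    print(sess.run(None, {"x": np.zeros((2, 2), dtype=np.float32)})[0])  # all ones

The payoff in this diff is that a blended text encoder or UNet, which can run to gigabytes, never has to round-trip through a temporary .onnx file before being handed to OnnxRuntimeModel.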