diff --git a/api/onnx_web/convert/diffusion/lora.py b/api/onnx_web/convert/diffusion/lora.py index 69af4db3..7a1d1b54 100644 --- a/api/onnx_web/convert/diffusion/lora.py +++ b/api/onnx_web/convert/diffusion/lora.py @@ -1,16 +1,19 @@ from argparse import ArgumentParser from logging import getLogger +from typing import Dict, List, Literal, Tuple from os import path -from typing import Dict, Literal import numpy as np import torch -from onnx import TensorProto, load, numpy_helper +from onnx import ModelProto, load, numpy_helper from onnx.checker import check_model from onnx.external_data_helper import ( + ExternalDataInfo, convert_model_to_external_data, + set_external_data, write_external_data_tensors, ) +from onnxruntime import OrtValue, InferenceSession, SessionOptions from safetensors.torch import load_file from onnx_web.convert.utils import ConversionContext @@ -23,6 +26,24 @@ logger = getLogger(__name__) ### +def buffer_external_data_tensors(model: ModelProto) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]: + external_data = [] + for tensor in model.graph.initializer: + name = tensor.name + + logger.info("externalizing tensor: %s", name) + if tensor.HasField("raw_data"): + npt = numpy_helper.to_array(tensor) + orv = OrtValue.ortvalue_from_numpy(npt) + external_data.append((name, orv)) + # mimic set_external_data + set_external_data(tensor, location="foo.bin") + tensor.name = name + tensor.ClearField("raw_data") + + return (model, external_data) + + def fix_initializer_name(key: str): # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0 @@ -186,23 +207,25 @@ def merge_lora( logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node)) - # save it back to disk - # TODO: save to memory instead - convert_model_to_external_data( - base_model, - all_tensors_to_one_file=True, - location=f"lora-{dest_type}-external.pb", - ) - bare_model = write_external_data_tensors(base_model, dest_path) + if dest_path is None or dest_path == "" or dest_path == "ort": + # convert to external data and save to memory + (bare_model, external_data) = buffer_external_data_tensors(base_model) + logger.info("saved external data for %s nodes", len(external_data)) - dest_file = path.join(dest_path, f"lora-{dest_type}.onnx") - with open(dest_file, "wb") as model_file: - model_file.write(bare_model.SerializeToString()) + external_names, external_values = zip(*external_data) + opts = SessionOptions() + opts.add_external_initializers(list(external_names), list(external_values)) + sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"]) + logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()]) + else: + convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb") + bare_model = write_external_data_tensors(base_model, dest_path) + dest_file = path.join(dest_path, f"lora-{dest_type}.onnx") - logger.info("model saved, checking...") - check_model(dest_file) + with open(dest_file, "w+b") as model_file: + model_file.write(bare_model.SerializeToString()) - logger.info("model successfully exported") + logger.info("successfully saved model: %s", dest_file) if __name__ == "__main__":