feat(api): blend ONNX models in memory

commit 2a8c85cb3e
parent 4c17edb267

@@ -1,16 +1,19 @@
 from argparse import ArgumentParser
 from logging import getLogger
+from typing import Dict, List, Literal, Tuple
 from os import path
-from typing import Dict, Literal
 
 import numpy as np
 import torch
-from onnx import TensorProto, load, numpy_helper
+from onnx import ModelProto, load, numpy_helper
 from onnx.checker import check_model
 from onnx.external_data_helper import (
+    ExternalDataInfo,
     convert_model_to_external_data,
+    set_external_data,
     write_external_data_tensors,
 )
+from onnxruntime import OrtValue, InferenceSession, SessionOptions
 from safetensors.torch import load_file
 
 from onnx_web.convert.utils import ConversionContext
@@ -23,6 +26,24 @@ logger = getLogger(__name__)
 ###
 
 
+def buffer_external_data_tensors(model: ModelProto) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
+    external_data = []
+    for tensor in model.graph.initializer:
+        name = tensor.name
+
+        logger.info("externalizing tensor: %s", name)
+        if tensor.HasField("raw_data"):
+            npt = numpy_helper.to_array(tensor)
+            orv = OrtValue.ortvalue_from_numpy(npt)
+            external_data.append((name, orv))
+            # mimic set_external_data
+            set_external_data(tensor, location="foo.bin")
+            tensor.name = name
+            tensor.ClearField("raw_data")
+
+    return (model, external_data)
+
+
 def fix_initializer_name(key: str):
     # lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
     # lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
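
The buffer_external_data_tensors helper added above plays the role of convert_model_to_external_data plus write_external_data_tensors, but keeps the weights in memory: each initializer loses its raw_data and is left pointing at a placeholder location, while the actual arrays are retained as OrtValue objects. A minimal usage sketch, assuming the helper is in scope and a hypothetical blended.onnx exists on disk:

from onnx import load

# hypothetical input path; any model with embedded initializer data works
model = load("blended.onnx")

# returns the stripped-down graph plus the weights as (name, OrtValue) pairs
(bare_model, external_data) = buffer_external_data_tensors(model)

for name, value in external_data:
    # each OrtValue still holds the original tensor contents
    print(name, value.shape())
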
@@ -186,23 +207,25 @@ def merge_lora(
 
     logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
 
-    # save it back to disk
-    # TODO: save to memory instead
-    convert_model_to_external_data(
-        base_model,
-        all_tensors_to_one_file=True,
-        location=f"lora-{dest_type}-external.pb",
-    )
-    bare_model = write_external_data_tensors(base_model, dest_path)
+    if dest_path is None or dest_path == "" or dest_path == "ort":
+        # convert to external data and save to memory
+        (bare_model, external_data) = buffer_external_data_tensors(base_model)
+        logger.info("saved external data for %s nodes", len(external_data))
+
+        external_names, external_values = zip(*external_data)
+        opts = SessionOptions()
+        opts.add_external_initializers(list(external_names), list(external_values))
+        sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
+        logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()])
+    else:
+        convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb")
+        bare_model = write_external_data_tensors(base_model, dest_path)
         dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
-        with open(dest_file, "wb") as model_file:
+        with open(dest_file, "w+b") as model_file:
             model_file.write(bare_model.SerializeToString())
 
         logger.info("model saved, checking...")
         check_model(dest_file)
 
-        logger.info("model successfully exported")
+        logger.info("successfully saved model: %s", dest_file)
 
 
 if __name__ == "__main__":
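
The new branch in merge_lora is the in-memory path the commit title describes: when dest_path is empty or the literal "ort", the blended graph is serialized without its weights and the weights are injected through SessionOptions.add_external_initializers, so nothing touches disk. A self-contained sketch of the same pattern, with a hypothetical blended.onnx standing in for the freshly blended base_model:

from onnx import load, numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions

# hypothetical input; in the commit this is the blended base_model
model = load("blended.onnx")

external_data = []
for tensor in model.graph.initializer:
    if tensor.HasField("raw_data"):
        # keep the weights as OrtValues and drop the bytes from the graph
        array = numpy_helper.to_array(tensor)
        external_data.append((tensor.name, OrtValue.ortvalue_from_numpy(array)))
        set_external_data(tensor, location="weights.bin")  # placeholder, never written
        tensor.ClearField("raw_data")

names, values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(names), list(values))

# the serialized graph stays small because the weights live in the OrtValues
sess = InferenceSession(model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])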