feat(api): blend ONNX models in memory
This commit is contained in:
parent
4c17edb267
commit
2a8c85cb3e
|
@ -1,16 +1,19 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from typing import Dict, List, Literal, Tuple
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Dict, Literal
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from onnx import TensorProto, load, numpy_helper
|
from onnx import ModelProto, load, numpy_helper
|
||||||
from onnx.checker import check_model
|
from onnx.checker import check_model
|
||||||
from onnx.external_data_helper import (
|
from onnx.external_data_helper import (
|
||||||
|
ExternalDataInfo,
|
||||||
convert_model_to_external_data,
|
convert_model_to_external_data,
|
||||||
|
set_external_data,
|
||||||
write_external_data_tensors,
|
write_external_data_tensors,
|
||||||
)
|
)
|
||||||
|
from onnxruntime import OrtValue, InferenceSession, SessionOptions
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from onnx_web.convert.utils import ConversionContext
|
from onnx_web.convert.utils import ConversionContext
|
||||||
|
@ -23,6 +26,24 @@ logger = getLogger(__name__)
|
||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
|
def buffer_external_data_tensors(model: ModelProto) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
|
||||||
|
external_data = []
|
||||||
|
for tensor in model.graph.initializer:
|
||||||
|
name = tensor.name
|
||||||
|
|
||||||
|
logger.info("externalizing tensor: %s", name)
|
||||||
|
if tensor.HasField("raw_data"):
|
||||||
|
npt = numpy_helper.to_array(tensor)
|
||||||
|
orv = OrtValue.ortvalue_from_numpy(npt)
|
||||||
|
external_data.append((name, orv))
|
||||||
|
# mimic set_external_data
|
||||||
|
set_external_data(tensor, location="foo.bin")
|
||||||
|
tensor.name = name
|
||||||
|
tensor.ClearField("raw_data")
|
||||||
|
|
||||||
|
return (model, external_data)
|
||||||
|
|
||||||
|
|
||||||
def fix_initializer_name(key: str):
|
def fix_initializer_name(key: str):
|
||||||
# lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
|
# lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
|
||||||
# lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
|
# lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
|
||||||
|
@ -186,23 +207,25 @@ def merge_lora(
|
||||||
|
|
||||||
logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
|
logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
|
||||||
|
|
||||||
# save it back to disk
|
if dest_path is None or dest_path == "" or dest_path == "ort":
|
||||||
# TODO: save to memory instead
|
# convert to external data and save to memory
|
||||||
convert_model_to_external_data(
|
(bare_model, external_data) = buffer_external_data_tensors(base_model)
|
||||||
base_model,
|
logger.info("saved external data for %s nodes", len(external_data))
|
||||||
all_tensors_to_one_file=True,
|
|
||||||
location=f"lora-{dest_type}-external.pb",
|
|
||||||
)
|
|
||||||
bare_model = write_external_data_tensors(base_model, dest_path)
|
|
||||||
|
|
||||||
dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
|
external_names, external_values = zip(*external_data)
|
||||||
with open(dest_file, "wb") as model_file:
|
opts = SessionOptions()
|
||||||
model_file.write(bare_model.SerializeToString())
|
opts.add_external_initializers(list(external_names), list(external_values))
|
||||||
|
sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
|
||||||
|
logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()])
|
||||||
|
else:
|
||||||
|
convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb")
|
||||||
|
bare_model = write_external_data_tensors(base_model, dest_path)
|
||||||
|
dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
|
||||||
|
|
||||||
logger.info("model saved, checking...")
|
with open(dest_file, "w+b") as model_file:
|
||||||
check_model(dest_file)
|
model_file.write(bare_model.SerializeToString())
|
||||||
|
|
||||||
logger.info("model successfully exported")
|
logger.info("successfully saved model: %s", dest_file)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue