1
0
Fork 0

feat(api): blend ONNX models in memory

This commit is contained in:
Sean Sube 2023-03-14 19:38:27 -05:00
parent 4c17edb267
commit 2a8c85cb3e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 39 additions and 16 deletions

View File

@ -1,16 +1,19 @@
from argparse import ArgumentParser
from logging import getLogger
from typing import Dict, List, Literal, Tuple
from os import path
from typing import Dict, Literal
import numpy as np
import torch
from onnx import TensorProto, load, numpy_helper
from onnx import ModelProto, load, numpy_helper
from onnx.checker import check_model
from onnx.external_data_helper import (
ExternalDataInfo,
convert_model_to_external_data,
set_external_data,
write_external_data_tensors,
)
from onnxruntime import OrtValue, InferenceSession, SessionOptions
from safetensors.torch import load_file
from onnx_web.convert.utils import ConversionContext
@ -23,6 +26,24 @@ logger = getLogger(__name__)
###
def buffer_external_data_tensors(model: ModelProto) -> Tuple[ModelProto, List[Tuple[str, OrtValue]]]:
    """Strip raw initializer data out of an ONNX model, keeping it in memory.

    For every graph initializer that carries embedded ``raw_data``, the bytes
    are converted to a numpy array and wrapped in an ``OrtValue``, and the
    tensor itself is re-marked as externally stored with its raw payload
    cleared. The returned pairs are intended to be fed back to onnxruntime
    (e.g. via ``SessionOptions.add_external_initializers``) so the model can
    be loaded without writing the tensor data to disk.

    :param model: model whose initializers will be externalized in place
    :return: tuple of (the same, now-slimmed model, list of (name, OrtValue)
        pairs holding the removed tensor data)
    """
    external_data = []
    for tensor in model.graph.initializer:
        name = tensor.name
        logger.info("externalizing tensor: %s", name)
        if tensor.HasField("raw_data"):
            # Decode the embedded bytes into numpy, then wrap as an OrtValue
            # so onnxruntime can consume the data directly from memory.
            npt = numpy_helper.to_array(tensor)
            orv = OrtValue.ortvalue_from_numpy(npt)
            external_data.append((name, orv))
            # mimic set_external_data
            # NOTE(review): "foo.bin" appears to be a dummy location — the data
            # is supplied in memory rather than read from that file. Confirm
            # onnxruntime ignores the location for in-memory initializers.
            set_external_data(tensor, location="foo.bin")
            tensor.name = name
            tensor.ClearField("raw_data")
    return (model, external_data)
def fix_initializer_name(key: str):
# lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_out_0.lora_down.weight
# lora, unet, up_block.3.attentions.2.transformer_blocks.0.attn2.to_out.0
@ -186,23 +207,25 @@ def merge_lora(
logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
# save it back to disk
# TODO: save to memory instead
convert_model_to_external_data(
base_model,
all_tensors_to_one_file=True,
location=f"lora-{dest_type}-external.pb",
)
bare_model = write_external_data_tensors(base_model, dest_path)
if dest_path is None or dest_path == "" or dest_path == "ort":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(base_model)
logger.info("saved external data for %s nodes", len(external_data))
dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
with open(dest_file, "wb") as model_file:
model_file.write(bare_model.SerializeToString())
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()])
else:
convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb")
bare_model = write_external_data_tensors(base_model, dest_path)
dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
logger.info("model saved, checking...")
check_model(dest_file)
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("model successfully exported")
logger.info("successfully saved model: %s", dest_file)
if __name__ == "__main__":