1
0
Fork 0

test loading UNet and encoder with LoRAs

This commit is contained in:
Sean Sube 2023-03-14 21:27:23 -05:00
parent 2a8c85cb3e
commit 56a4519818
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 53 additions and 21 deletions

View File

@ -205,27 +205,15 @@ def merge_lora(
else:
logger.info("could not find any nodes for %s", base_key)
logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
logger.info(
"node counts: %s -> %s, %s -> %s",
len(fixed_initializer_names),
len(base_model.graph.initializer),
len(fixed_node_names),
len(base_model.graph.node)
)
if dest_path is None or dest_path == "" or dest_path == "ort":
# convert to external data and save to memory
(bare_model, external_data) = buffer_external_data_tensors(base_model)
logger.info("saved external data for %s nodes", len(external_data))
external_names, external_values = zip(*external_data)
opts = SessionOptions()
opts.add_external_initializers(list(external_names), list(external_values))
sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()])
else:
convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb")
bare_model = write_external_data_tensors(base_model, dest_path)
dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString())
logger.info("successfully saved model: %s", dest_file)
return base_model
if __name__ == "__main__":
@ -238,4 +226,24 @@ if __name__ == "__main__":
args = parser.parse_args()
logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
if args.dest in (None, "", "ort"):
    # No output directory was requested: keep the tensors in memory and
    # check that ONNX runtime can build a session from the blended graph.
    (bare_model, external_data) = buffer_external_data_tensors(blend_model)
    logger.info("saved external data for %s nodes", len(external_data))

    # external_data is unpacked as (name, value) pairs; split it into the two
    # parallel lists that add_external_initializers is called with.
    external_names, external_values = zip(*external_data)
    opts = SessionOptions()
    opts.add_external_initializers(list(external_names), list(external_values))

    sess = InferenceSession(
        bare_model.SerializeToString(),
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )
    logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
else:
    # Save into args.dest, the same option that is tested by the condition
    # above and handed to merge_lora. The script never reads any other
    # destination option, so both the tensor file and the graph go there.
    convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
    bare_model = write_external_data_tensors(blend_model, args.dest)
    dest_file = path.join(args.dest, f"lora-{args.type}.onnx")

    with open(dest_file, "w+b") as model_file:
        model_file.write(bare_model.SerializeToString())

    logger.info("successfully saved blended model: %s", dest_file)

View File

@ -21,6 +21,7 @@ from diffusers import (
PNDMScheduler,
StableDiffusionPipeline,
)
from onnxruntime import SessionOptions
from transformers import CLIPTokenizer
from onnx_web.diffusers.utils import expand_prompt
@ -35,6 +36,7 @@ try:
except ImportError:
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
from ..params import DeviceParams, Size
from ..server import ServerContext
from ..utils import run_gc
@ -220,6 +222,28 @@ def load_pipeline(
path.join(inversion, "tokenizer"),
)
# test LoRA blending
lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
"arch",
"glass",
]]
logger.info("blending text encoder with LoRA models: %s", lora_models)
blended_text_encoder = merge_lora("text_encoder", lora_models, None, "text_encoder")
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
text_encoder_opts = SessionOptions()
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
components["text_encoder"] = OnnxRuntimeModel.from_pretrained(text_encoder_model, sess_options=text_encoder_opts)
logger.info("blending unet with LoRA models: %s", lora_models)
blended_unet = merge_lora("unet", lora_models, None, "unet")
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
unet_names, unet_values = zip(*unet_data)
unet_opts = SessionOptions()
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
components["unet"] = OnnxRuntimeModel.from_pretrained(unet_model, sess_options=unet_opts)
pipe = pipeline.from_pretrained(
model,
custom_pipeline=custom_pipeline,