test loading UNet and encoder with LoRAs
commit 56a4519818
parent 2a8c85cb3e
@@ -205,27 +205,15 @@ def merge_lora(
         else:
             logger.info("could not find any nodes for %s", base_key)
 
-    logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
+    logger.info(
+        "node counts: %s -> %s, %s -> %s",
+        len(fixed_initializer_names),
+        len(base_model.graph.initializer),
+        len(fixed_node_names),
+        len(base_model.graph.node)
+    )
 
-    if dest_path is None or dest_path == "" or dest_path == "ort":
-        # convert to external data and save to memory
-        (bare_model, external_data) = buffer_external_data_tensors(base_model)
-        logger.info("saved external data for %s nodes", len(external_data))
-
-        external_names, external_values = zip(*external_data)
-        opts = SessionOptions()
-        opts.add_external_initializers(list(external_names), list(external_values))
-        sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
-        logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()])
-    else:
-        convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb")
-        bare_model = write_external_data_tensors(base_model, dest_path)
-        dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
-
-        with open(dest_file, "w+b") as model_file:
-            model_file.write(bare_model.SerializeToString())
-
-        logger.info("successfully saved model: %s", dest_file)
+    return base_model
 
 
 if __name__ == "__main__":
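With this hunk, merge_lora performs no I/O of its own: it returns the blended ONNX ModelProto, and the callers (the CLI entry point below and load_pipeline) decide between loading in memory and writing to disk. A minimal sketch of the new contract, with the positional argument order inferred from the call sites in this commit (the parameter meanings are assumptions, not shown in the diff):

    from onnx import ModelProto

    # inferred: merge_lora(base, lora_models, dest_path, dest_type, lora_weights=None)
    blended: ModelProto = merge_lora(
        "unet",                 # base model to patch
        ["arch.safetensors"],   # LoRA files to blend in
        None,                   # dest_path: None leaves saving/loading to the caller
        "unet",                 # dest_type: presumably selects unet vs text_encoder key handling
    )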
@@ -238,4 +226,24 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
     logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
-    merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
+
+    blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
+    if args.dest is None or args.dest == "" or args.dest == "ort":
+        # convert to external data and save to memory
+        (bare_model, external_data) = buffer_external_data_tensors(blend_model)
+        logger.info("saved external data for %s nodes", len(external_data))
+
+        external_names, external_values = zip(*external_data)
+        opts = SessionOptions()
+        opts.add_external_initializers(list(external_names), list(external_values))
+        sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
+        logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
+    else:
+        convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
+        bare_model = write_external_data_tensors(blend_model, args.dest)
+        dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
+
+        with open(dest_file, "w+b") as model_file:
+            model_file.write(bare_model.SerializeToString())
+
+        logger.info("successfully saved blended model: %s", dest_file)
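buffer_external_data_tensors is imported but never shown in this diff. A sketch of what it presumably does, based on how the "ort" branch above consumes its result: strip each raw-data initializer out of the graph and hand it back as a named OrtValue, so the serialized model stays small and onnxruntime resolves the weights through SessionOptions.add_external_initializers instead.

    import onnx
    from onnx import numpy_helper
    from onnx.external_data_helper import set_external_data
    from onnxruntime import OrtValue

    def buffer_external_data_tensors_sketch(model: onnx.ModelProto):
        # move large initializers out of the graph into in-memory OrtValues
        external_data = []
        for tensor in model.graph.initializer:
            if tensor.HasField("raw_data"):
                external_data.append(
                    (tensor.name, OrtValue.ortvalue_from_numpy(numpy_helper.to_array(tensor)))
                )
                # mark the tensor external so ORT looks it up by name at load time
                set_external_data(tensor, location="in-memory")
                tensor.ClearField("raw_data")
        return (model, external_data)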
@@ -21,6 +21,7 @@ from diffusers import (
     PNDMScheduler,
     StableDiffusionPipeline,
 )
+from onnxruntime import SessionOptions
 from transformers import CLIPTokenizer
 
 from onnx_web.diffusers.utils import expand_prompt
@@ -35,6 +36,7 @@ try:
 except ImportError:
     from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
 
+from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
 from ..params import DeviceParams, Size
 from ..server import ServerContext
 from ..utils import run_gc
@@ -220,6 +222,28 @@ def load_pipeline(
             path.join(inversion, "tokenizer"),
         )
 
+    # test LoRA blending
+    lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
+        "arch",
+        "glass",
+    ]]
+
+    logger.info("blending text encoder with LoRA models: %s", lora_models)
+    blended_text_encoder = merge_lora("text_encoder", lora_models, None, "text_encoder")
+    (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
+    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
+    text_encoder_opts = SessionOptions()
+    text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
+    components["text_encoder"] = OnnxRuntimeModel.from_pretrained(text_encoder_model, sess_options=text_encoder_opts)
+
+    logger.info("blending unet with LoRA models: %s", lora_models)
+    blended_unet = merge_lora("unet", lora_models, None, "unet")
+    (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
+    unet_names, unet_values = zip(*unet_data)
+    unet_opts = SessionOptions()
+    unet_opts.add_external_initializers(list(unet_names), list(unet_values))
+    components["unet"] = OnnxRuntimeModel.from_pretrained(unet_model, sess_options=unet_opts)
+
     pipe = pipeline.from_pretrained(
         model,
         custom_pipeline=custom_pipeline,
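The text-encoder and unet blocks in the hunk above are identical up to naming; if this test sticks, they could collapse into a helper. A sketch under that assumption (blend_component is hypothetical, not part of this commit; it reuses only calls that appear above):

    from typing import List
    from diffusers import OnnxRuntimeModel
    from onnxruntime import SessionOptions
    # merge_lora and buffer_external_data_tensors as imported earlier in this commit

    def blend_component(component: str, lora_models: List[str]) -> OnnxRuntimeModel:
        # hypothetical helper: same blend/buffer/load steps as the hunk above
        blended = merge_lora(component, lora_models, None, component)
        model, data = buffer_external_data_tensors(blended)
        names, values = zip(*data)
        opts = SessionOptions()
        opts.add_external_initializers(list(names), list(values))
        return OnnxRuntimeModel.from_pretrained(model, sess_options=opts)

    # components["text_encoder"] = blend_component("text_encoder", lora_models)
    # components["unet"] = blend_component("unet", lora_models)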