test loading UNet and encoder with LoRAs
This commit is contained in:
parent
2a8c85cb3e
commit
56a4519818
|
@ -205,27 +205,15 @@ def merge_lora(
|
||||||
else:
|
else:
|
||||||
logger.info("could not find any nodes for %s", base_key)
|
logger.info("could not find any nodes for %s", base_key)
|
||||||
|
|
||||||
logger.info("node counts: %s -> %s, %s -> %s", len(fixed_initializer_names), len(base_model.graph.initializer), len(fixed_node_names), len(base_model.graph.node))
|
logger.info(
|
||||||
|
"node counts: %s -> %s, %s -> %s",
|
||||||
|
len(fixed_initializer_names),
|
||||||
|
len(base_model.graph.initializer),
|
||||||
|
len(fixed_node_names),
|
||||||
|
len(base_model.graph.node)
|
||||||
|
)
|
||||||
|
|
||||||
if dest_path is None or dest_path == "" or dest_path == "ort":
|
return base_model
|
||||||
# convert to external data and save to memory
|
|
||||||
(bare_model, external_data) = buffer_external_data_tensors(base_model)
|
|
||||||
logger.info("saved external data for %s nodes", len(external_data))
|
|
||||||
|
|
||||||
external_names, external_values = zip(*external_data)
|
|
||||||
opts = SessionOptions()
|
|
||||||
opts.add_external_initializers(list(external_names), list(external_values))
|
|
||||||
sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
|
|
||||||
logger.info("successfully loaded model: %s", [i.name for i in sess.get_inputs()])
|
|
||||||
else:
|
|
||||||
convert_model_to_external_data(base_model, all_tensors_to_one_file=True, location=f"lora-{dest_type}.pb")
|
|
||||||
bare_model = write_external_data_tensors(base_model, dest_path)
|
|
||||||
dest_file = path.join(dest_path, f"lora-{dest_type}.onnx")
|
|
||||||
|
|
||||||
with open(dest_file, "w+b") as model_file:
|
|
||||||
model_file.write(bare_model.SerializeToString())
|
|
||||||
|
|
||||||
logger.info("successfully saved model: %s", dest_file)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -238,4 +226,24 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
|
logger.info("merging %s with %s with weights: %s", args.lora_models, args.base, args.lora_weights)
|
||||||
merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
|
|
||||||
|
blend_model = merge_lora(args.base, args.lora_models, args.dest, args.type, args.lora_weights)
|
||||||
|
if args.dest is None or args.dest == "" or args.dest == "ort":
|
||||||
|
# convert to external data and save to memory
|
||||||
|
(bare_model, external_data) = buffer_external_data_tensors(blend_model)
|
||||||
|
logger.info("saved external data for %s nodes", len(external_data))
|
||||||
|
|
||||||
|
external_names, external_values = zip(*external_data)
|
||||||
|
opts = SessionOptions()
|
||||||
|
opts.add_external_initializers(list(external_names), list(external_values))
|
||||||
|
sess = InferenceSession(bare_model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"])
|
||||||
|
logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
|
||||||
|
else:
|
||||||
|
convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
|
||||||
|
bare_model = write_external_data_tensors(blend_model, args.path)
|
||||||
|
dest_file = path.join(args.path, f"lora-{args.type}.onnx")
|
||||||
|
|
||||||
|
with open(dest_file, "w+b") as model_file:
|
||||||
|
model_file.write(bare_model.SerializeToString())
|
||||||
|
|
||||||
|
logger.info("successfully saved blended model: %s", dest_file)
|
||||||
|
|
|
@ -21,6 +21,7 @@ from diffusers import (
|
||||||
PNDMScheduler,
|
PNDMScheduler,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
|
from onnxruntime import SessionOptions
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from onnx_web.diffusers.utils import expand_prompt
|
from onnx_web.diffusers.utils import expand_prompt
|
||||||
|
@ -35,6 +36,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
from ..diffusers.stub_scheduler import StubScheduler as UniPCMultistepScheduler
|
||||||
|
|
||||||
|
from ..convert.diffusion.lora import merge_lora, buffer_external_data_tensors
|
||||||
from ..params import DeviceParams, Size
|
from ..params import DeviceParams, Size
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
@ -220,6 +222,28 @@ def load_pipeline(
|
||||||
path.join(inversion, "tokenizer"),
|
path.join(inversion, "tokenizer"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# test LoRA blending
|
||||||
|
lora_models = [path.join(server.model_path, "lora", f"{i}.safetensors") for i in [
|
||||||
|
"arch",
|
||||||
|
"glass",
|
||||||
|
]]
|
||||||
|
|
||||||
|
logger.info("blending text encoder with LoRA models: %s", lora_models)
|
||||||
|
blended_text_encoder = merge_lora("text_encoder", lora_models, None, "text_encoder")
|
||||||
|
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
||||||
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
|
text_encoder_opts = SessionOptions()
|
||||||
|
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
||||||
|
components["text_encoder"] = OnnxRuntimeModel.from_pretrained(text_encoder_model, sess_options=text_encoder_opts)
|
||||||
|
|
||||||
|
logger.info("blending unet with LoRA models: %s", lora_models)
|
||||||
|
blended_unet = merge_lora("unet", lora_models, None, "unet")
|
||||||
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
|
unet_names, unet_values = zip(*unet_data)
|
||||||
|
unet_opts = SessionOptions()
|
||||||
|
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||||
|
components["unet"] = OnnxRuntimeModel.from_pretrained(unet_model, sess_options=unet_opts)
|
||||||
|
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
custom_pipeline=custom_pipeline,
|
custom_pipeline=custom_pipeline,
|
||||||
|
|
Loading…
Reference in New Issue