
remove extra param, correct output path

Sean Sube 2023-03-14 23:32:18 -05:00
parent 8cf6f2215d
commit 91210ee236
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 4 additions and 4 deletions


@ -238,8 +238,8 @@ if __name__ == "__main__":
logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()]) logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
else: else:
convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb") convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
bare_model = write_external_data_tensors(blend_model, args.path) bare_model = write_external_data_tensors(blend_model, args.dest)
dest_file = path.join(args.path, f"lora-{args.type}.onnx") dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
with open(dest_file, "w+b") as model_file: with open(dest_file, "w+b") as model_file:
model_file.write(bare_model.SerializeToString()) model_file.write(bare_model.SerializeToString())


@@ -248,7 +248,7 @@ def load_pipeline(
     logger.info("blending base model %s with LoRA models: %s", model, lora_models)
     # blend and load text encoder
-    blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, None, "text_encoder")
+    blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder")
     (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
     text_encoder_names, text_encoder_values = zip(*text_encoder_data)
     text_encoder_opts = SessionOptions()
@@ -262,7 +262,7 @@ def load_pipeline(
     )
     # blend and load unet
-    blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, None, "unet")
+    blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
     (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
     unet_names, unet_values = zip(*unet_data)
     unet_opts = SessionOptions()
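
For context, the loading pattern around the updated merge_lora signature looks roughly like the sketch below. This is illustrative only, not the authoritative implementation: merge_lora and buffer_external_data_tensors are the project's own helpers shown in the diff above, model and lora_models stand in for the caller's base model path and LoRA list, and add_external_initializers assumes a reasonably recent onnxruntime.

    from os import path

    from onnxruntime import InferenceSession, SessionOptions

    # merge_lora now takes only the model path, LoRA list, and model type name
    blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")

    # keep the blended tensors in memory and hand them to the session as
    # external initializers instead of writing them back to disk
    (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
    unet_names, unet_values = zip(*unet_data)

    unet_opts = SessionOptions()
    unet_opts.add_external_initializers(list(unet_names), list(unet_values))

    unet_session = InferenceSession(unet_model.SerializeToString(), sess_options=unet_opts)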