remove extra param, correct output path
This commit is contained in:
parent
8cf6f2215d
commit
91210ee236
|
@ -238,8 +238,8 @@ if __name__ == "__main__":
|
||||||
logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
|
logger.info("successfully loaded blended model: %s", [i.name for i in sess.get_inputs()])
|
||||||
else:
|
else:
|
||||||
convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
|
convert_model_to_external_data(blend_model, all_tensors_to_one_file=True, location=f"lora-{args.type}.pb")
|
||||||
bare_model = write_external_data_tensors(blend_model, args.path)
|
bare_model = write_external_data_tensors(blend_model, args.dest)
|
||||||
dest_file = path.join(args.path, f"lora-{args.type}.onnx")
|
dest_file = path.join(args.dest, f"lora-{args.type}.onnx")
|
||||||
|
|
||||||
with open(dest_file, "w+b") as model_file:
|
with open(dest_file, "w+b") as model_file:
|
||||||
model_file.write(bare_model.SerializeToString())
|
model_file.write(bare_model.SerializeToString())
|
||||||
|
|
|
@ -248,7 +248,7 @@ def load_pipeline(
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||||
|
|
||||||
# blend and load text encoder
|
# blend and load text encoder
|
||||||
blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, None, "text_encoder")
|
blended_text_encoder = merge_lora(path.join(model, "text_encoder", "model.onnx"), lora_models, "text_encoder")
|
||||||
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
text_encoder_opts = SessionOptions()
|
text_encoder_opts = SessionOptions()
|
||||||
|
@ -262,7 +262,7 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# blend and load unet
|
# blend and load unet
|
||||||
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, None, "unet")
|
blended_unet = merge_lora(path.join(model, "unet", "model.onnx"), lora_models, "unet")
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = SessionOptions()
|
unet_opts = SessionOptions()
|
||||||
|
|
Loading…
Reference in New Issue