
load base models from model dir, use correct ORT provider

This commit is contained in:
Sean Sube 2023-03-14 21:57:37 -05:00
parent 56a4519818
commit ce05e76947
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 4 additions and 4 deletions


@@ -229,20 +229,20 @@ def load_pipeline(
 ]]
 logger.info("blending text encoder with LoRA models: %s", lora_models)
-blended_text_encoder = merge_lora("text_encoder", lora_models, None, "text_encoder")
+blended_text_encoder = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), lora_models, None, "text_encoder")
 (text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
 text_encoder_names, text_encoder_values = zip(*text_encoder_data)
 text_encoder_opts = SessionOptions()
 text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
-components["text_encoder"] = OnnxRuntimeModel.from_pretrained(text_encoder_model, sess_options=text_encoder_opts)
+components["text_encoder"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(text_encoder_model.SerializeToString(), provider=device.ort_provider(), sess_options=text_encoder_opts))
 logger.info("blending unet with LoRA models: %s", lora_models)
-blended_unet = merge_lora("unet", lora_models, None, "unet")
+blended_unet = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/unet/model.onnx"), lora_models, None, "unet")
 (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
 unet_names, unet_values = zip(*unet_data)
 unet_opts = SessionOptions()
 unet_opts.add_external_initializers(list(unet_names), list(unet_values))
-components["unet"] = OnnxRuntimeModel.from_pretrained(unet_model, sess_options=unet_opts)
+components["unet"] = OnnxRuntimeModel(OnnxRuntimeModel(OnnxRuntimeModel.load_model(unet_model.SerializeToString(), provider=device.ort_provider(), sess_options=unet_opts)) if False else OnnxRuntimeModel.load_model(unet_model.SerializeToString(), provider=device.ort_provider(), sess_options=unet_opts))
 pipe = pipeline.from_pretrained(
 model,
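
For context, a minimal sketch of the pattern the right-hand side adopts: instead of letting OnnxRuntimeModel.from_pretrained pick a default execution provider, the blended graph is loaded into an ONNX Runtime session that is told which provider to use. The model directory, file path, and provider string below are illustrative assumptions; merge_lora and buffer_external_data_tensors are onnx-web helpers and are not reproduced here.

```python
from os import path

from onnxruntime import InferenceSession, SessionOptions

model_path = "/models"  # assumed stand-in for server.model_path
base = path.join(model_path, "stable-diffusion-onnx-v1-5/unet/model.onnx")

opts = SessionOptions()
# if the blended weights were buffered out of the graph, they would be
# re-attached here via opts.add_external_initializers(names, values)

session = InferenceSession(
    base,
    sess_options=opts,
    providers=["CUDAExecutionProvider"],  # e.g. what device.ort_provider() might resolve to
)
print([i.name for i in session.get_inputs()])  # quick sanity check of the loaded graph
```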