load base models from model dir, use correct ORT provider
This commit is contained in:
parent
56a4519818
commit
ce05e76947
|
@ -229,20 +229,20 @@ def load_pipeline(
|
||||||
]]
|
]]
|
||||||
|
|
||||||
logger.info("blending text encoder with LoRA models: %s", lora_models)
|
logger.info("blending text encoder with LoRA models: %s", lora_models)
|
||||||
blended_text_encoder = merge_lora("text_encoder", lora_models, None, "text_encoder")
|
blended_text_encoder = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/text_encoder/model.onnx"), lora_models, None, "text_encoder")
|
||||||
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
(text_encoder_model, text_encoder_data) = buffer_external_data_tensors(blended_text_encoder)
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
text_encoder_opts = SessionOptions()
|
text_encoder_opts = SessionOptions()
|
||||||
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
text_encoder_opts.add_external_initializers(list(text_encoder_names), list(text_encoder_values))
|
||||||
components["text_encoder"] = OnnxRuntimeModel.from_pretrained(text_encoder_model, sess_options=text_encoder_opts)
|
components["text_encoder"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(text_encoder_model.SerializeToString(), provider=device.ort_provider(), sess_options=text_encoder_opts))
|
||||||
|
|
||||||
logger.info("blending unet with LoRA models: %s", lora_models)
|
logger.info("blending unet with LoRA models: %s", lora_models)
|
||||||
blended_unet = merge_lora("unet", lora_models, None, "unet")
|
blended_unet = merge_lora(path.join(server.model_path, "stable-diffusion-onnx-v1-5/unet/model.onnx"), lora_models, None, "unet")
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = SessionOptions()
|
unet_opts = SessionOptions()
|
||||||
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
unet_opts.add_external_initializers(list(unet_names), list(unet_values))
|
||||||
components["unet"] = OnnxRuntimeModel.from_pretrained(unet_model, sess_options=unet_opts)
|
components["unet"] = OnnxRuntimeModel(OnnxRuntimeModel.load_model(unet_model.SerializeToString(), provider=device.ort_provider(), sess_options=unet_opts))
|
||||||
|
|
||||||
pipe = pipeline.from_pretrained(
|
pipe = pipeline.from_pretrained(
|
||||||
model,
|
model,
|
||||||
|
|
Loading…
Reference in New Issue