1
0
Fork 0

load UNet with ControlNet weights as needed

This commit is contained in:
Sean Sube 2023-04-13 23:05:00 -05:00
parent 06d55b0f1f
commit c2f8ce5814
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 16 additions and 11 deletions

View File

@ -199,7 +199,21 @@ def load_pipeline(
) )
} }
# shared components
text_encoder = None text_encoder = None
unet = None
# ControlNet component
if pipeline == "controlnet" and control is not None:
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
path.join(server.model_path, "control", f"{control.name}.onnx"),
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
)
unet = path.join(model, "cnet", ONNX_MODEL)
# Textual Inversion blending # Textual Inversion blending
if inversions is not None and len(inversions) > 0: if inversions is not None and len(inversions) > 0:
@ -276,9 +290,10 @@ def load_pipeline(
) )
# blend and load unet # blend and load unet
unet = unet or path.join(model, "unet", ONNX_MODEL)
blended_unet = blend_loras( blended_unet = blend_loras(
server, server,
path.join(model, "unet", ONNX_MODEL), unet,
list(zip(lora_models, lora_weights)), list(zip(lora_models, lora_weights)),
"unet", "unet",
) )
@ -294,16 +309,6 @@ def load_pipeline(
) )
) )
# ControlNet component
if pipeline == "controlnet" and control is not None:
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
path.join(server.model_path, "control", f"{control.name}.onnx"),
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
)
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline) pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__) logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained( pipe = pipeline_class.from_pretrained(