load UNet with ControlNet weights as needed
This commit is contained in:
parent
06d55b0f1f
commit
c2f8ce5814
|
@ -199,7 +199,21 @@ def load_pipeline(
|
|||
)
|
||||
}
|
||||
|
||||
# shared components
|
||||
text_encoder = None
|
||||
unet = None
|
||||
|
||||
# ControlNet component
|
||||
if pipeline == "controlnet" and control is not None:
|
||||
components["controlnet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
path.join(server.model_path, "control", f"{control.name}.onnx"),
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
||||
unet = path.join(model, "cnet", ONNX_MODEL)
|
||||
|
||||
# Textual Inversion blending
|
||||
if inversions is not None and len(inversions) > 0:
|
||||
|
@ -276,9 +290,10 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# blend and load unet
|
||||
unet = unet or path.join(model, "unet", ONNX_MODEL)
|
||||
blended_unet = blend_loras(
|
||||
server,
|
||||
path.join(model, "unet", ONNX_MODEL),
|
||||
unet,
|
||||
list(zip(lora_models, lora_weights)),
|
||||
"unet",
|
||||
)
|
||||
|
@ -294,16 +309,6 @@ def load_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# ControlNet component
|
||||
if pipeline == "controlnet" and control is not None:
|
||||
components["controlnet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
path.join(server.model_path, "control", f"{control.name}.onnx"),
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||
pipe = pipeline_class.from_pretrained(
|
||||
|
|
Loading…
Reference in New Issue