load UNet with ControlNet weights as needed
This commit is contained in:
parent
06d55b0f1f
commit
c2f8ce5814
|
@ -199,7 +199,21 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# shared components
|
||||||
text_encoder = None
|
text_encoder = None
|
||||||
|
unet = None
|
||||||
|
|
||||||
|
# ControlNet component
|
||||||
|
if pipeline == "controlnet" and control is not None:
|
||||||
|
components["controlnet"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
path.join(server.model_path, "control", f"{control.name}.onnx"),
|
||||||
|
provider=device.ort_provider(),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
unet = path.join(model, "cnet", ONNX_MODEL)
|
||||||
|
|
||||||
# Textual Inversion blending
|
# Textual Inversion blending
|
||||||
if inversions is not None and len(inversions) > 0:
|
if inversions is not None and len(inversions) > 0:
|
||||||
|
@ -276,9 +290,10 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
|
|
||||||
# blend and load unet
|
# blend and load unet
|
||||||
|
unet = unet or path.join(model, "unet", ONNX_MODEL)
|
||||||
blended_unet = blend_loras(
|
blended_unet = blend_loras(
|
||||||
server,
|
server,
|
||||||
path.join(model, "unet", ONNX_MODEL),
|
unet,
|
||||||
list(zip(lora_models, lora_weights)),
|
list(zip(lora_models, lora_weights)),
|
||||||
"unet",
|
"unet",
|
||||||
)
|
)
|
||||||
|
@ -294,16 +309,6 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# ControlNet component
|
|
||||||
if pipeline == "controlnet" and control is not None:
|
|
||||||
components["controlnet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
path.join(server.model_path, "control", f"{control.name}.onnx"),
|
|
||||||
provider=device.ort_provider(),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||||
pipe = pipeline_class.from_pretrained(
|
pipe = pipeline_class.from_pretrained(
|
||||||
|
|
Loading…
Reference in New Issue