apply LoRAs to correct UNet
This commit is contained in:
parent
33f72187c5
commit
30968f7c33
|
@ -197,19 +197,21 @@ def load_pipeline(
|
|||
|
||||
# shared components
|
||||
text_encoder = None
|
||||
unet = None
|
||||
unet_type = "unet"
|
||||
|
||||
# ControlNet component
|
||||
if pipeline == "controlnet" and control is not None:
|
||||
cnet_path = path.join(server.model_path, "control", f"{control.name}.onnx")
|
||||
logger.debug("loading ControlNet weights from %s", cnet_path)
|
||||
components["controlnet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
path.join(server.model_path, "control", f"{control.name}.onnx"),
|
||||
cnet_path,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
)
|
||||
|
||||
unet = path.join(model, "cnet", ONNX_MODEL)
|
||||
unet_type = "cnet"
|
||||
|
||||
# Textual Inversion blending
|
||||
if inversions is not None and len(inversions) > 0:
|
||||
|
@ -286,7 +288,7 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
# blend and load unet
|
||||
unet = unet or path.join(model, "unet", ONNX_MODEL)
|
||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
||||
blended_unet = blend_loras(
|
||||
server,
|
||||
unet,
|
||||
|
@ -305,6 +307,18 @@ def load_pipeline(
|
|||
)
|
||||
)
|
||||
|
||||
# make sure a UNet has been loaded
|
||||
if "unet" not in components:
|
||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
||||
logger.debug("loading UNet (%s) from %s", unet_type, unet)
|
||||
components["unet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
unet,
|
||||
provider=device.ort_provider(),
|
||||
sess_options=unet_opts,
|
||||
)
|
||||
)
|
||||
|
||||
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
|
||||
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
|
||||
pipe = pipeline_class.from_pretrained(
|
||||
|
|
Loading…
Reference in New Issue