1
0
Fork 0

apply LoRAs to correct UNet

This commit is contained in:
Sean Sube 2023-04-15 12:28:55 -05:00
parent 33f72187c5
commit 30968f7c33
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 18 additions and 4 deletions

View File

@ -197,19 +197,21 @@ def load_pipeline(
# shared components
text_encoder = None
unet = None
unet_type = "unet"
# ControlNet component
if pipeline == "controlnet" and control is not None:
cnet_path = path.join(server.model_path, "control", f"{control.name}.onnx")
logger.debug("loading ControlNet weights from %s", cnet_path)
components["controlnet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
path.join(server.model_path, "control", f"{control.name}.onnx"),
cnet_path,
provider=device.ort_provider(),
sess_options=device.sess_options(),
)
)
unet = path.join(model, "cnet", ONNX_MODEL)
unet_type = "cnet"
# Textual Inversion blending
if inversions is not None and len(inversions) > 0:
@ -286,7 +288,7 @@ def load_pipeline(
)
# blend and load unet
unet = unet or path.join(model, "unet", ONNX_MODEL)
unet = path.join(model, unet_type, ONNX_MODEL)
blended_unet = blend_loras(
server,
unet,
@ -305,6 +307,18 @@ def load_pipeline(
)
)
# make sure a UNet has been loaded
if "unet" not in components:
unet = path.join(model, unet_type, ONNX_MODEL)
logger.debug("loading UNet (%s) from %s", unet_type, unet)
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
unet,
provider=device.ort_provider(),
sess_options=unet_opts,
)
)
pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(