diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index f123ddc1..a66e6ffe 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -22,7 +22,6 @@ from .patches.scheduler import SchedulerPatch from .patches.unet import UNetWrapper from .patches.vae import VAEWrapper from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline -from .pipelines.highres import OnnxStableDiffusionHighresPipeline from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline @@ -55,7 +54,7 @@ logger = getLogger(__name__) available_pipelines = { "controlnet": OnnxStableDiffusionControlNetPipeline, - "highres": OnnxStableDiffusionHighresPipeline, + # "highres": OnnxStableDiffusionHighresPipeline, "img2img": OnnxStableDiffusionImg2ImgPipeline, "img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline, "inpaint": OnnxStableDiffusionInpaintPipeline, @@ -657,7 +656,7 @@ def patch_pipeline( logger.debug("patching pipeline scheduler") original_scheduler = pipe.scheduler - pipe.scheduler = SchedulerPatch(server, original_scheduler) + pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img()) logger.debug("patching pipeline UNet") original_unet = pipe.unet diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index 6492dcfd..e9527dd7 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -13,11 +13,13 @@ logger = getLogger(__name__) class SchedulerPatch: server: ServerContext + text_pipeline: bool wrapped: Any - def __init__(self, server: ServerContext, scheduler): + def __init__(self, server: ServerContext, scheduler: Any, text_pipeline: bool): self.server = server self.wrapped = scheduler + self.text_pipeline = text_pipeline def __getattr__(self, attr): return getattr(self.wrapped, attr) @@ -27,7 +29,7 @@ class SchedulerPatch: ) -> SchedulerOutput: result = self.wrapped.step(model_output, timestep, sample) - if self.server.has_feature("mirror-latents"): + if self.text_pipeline and self.server.has_feature("mirror-latents"): logger.info("using experimental latent mirroring") white_point = 0 diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index c523750e..584bfb38 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -321,6 +321,9 @@ class ImageParams: def is_pix2pix(self): return self.pipeline == "pix2pix" + def is_txt2img(self): + return self.pipeline in ["txt2img", "txt2img-sdxl"] + def is_xl(self): return self.pipeline.endswith("-sdxl")