diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index a66e6ffe..5d34d8f8 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -642,6 +642,19 @@ def optimize_pipeline( logger.warning("error while enabling memory efficient attention: %s", e) +IMAGE_PIPELINES = [ + OnnxStableDiffusionControlNetPipeline, + OnnxStableDiffusionImg2ImgPipeline, + OnnxStableDiffusionInpaintPipeline, + OnnxStableDiffusionInstructPix2PixPipeline, + OnnxStableDiffusionLongPromptWeightingPipeline, + OnnxStableDiffusionPanoramaPipeline, + OnnxStableDiffusionUpscalePipeline, + ORTStableDiffusionXLImg2ImgPipeline, + ORTStableDiffusionXLPanoramaPipeline, +] + + def patch_pipeline( server: ServerContext, pipe: StableDiffusionPipeline, @@ -654,9 +667,15 @@ def patch_pipeline( logger.debug("patching prompt encoder") pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) - logger.debug("patching pipeline scheduler") + # the pipeline requested in params may not be the one currently being used, especially during the later img2img + # stages of a highres pipeline, so we need to check the pipeline type + is_text_pipeline = type(pipe) not in IMAGE_PIPELINES + logger.debug( + "patching pipeline scheduler for %s pipeline", + "txt2img" if is_text_pipeline else "img2img", + ) original_scheduler = pipe.scheduler - pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img()) + pipe.scheduler = SchedulerPatch(server, original_scheduler, is_text_pipeline) logger.debug("patching pipeline UNet") original_unet = pipe.unet diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 584bfb38..c523750e 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -321,9 +321,6 @@ class ImageParams: def is_pix2pix(self): return self.pipeline == "pix2pix" - def is_txt2img(self): - return self.pipeline in ["txt2img", "txt2img-sdxl"] - def is_xl(self): return self.pipeline.endswith("-sdxl")