1
0
Fork 0

use pipeline type to decide scheduler patch options

This commit is contained in:
Sean Sube 2024-01-27 17:48:14 -06:00
parent ddb8f4fadf
commit d8a0f5f15f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 21 additions and 5 deletions

View File

@ -642,6 +642,19 @@ def optimize_pipeline(
logger.warning("error while enabling memory efficient attention: %s", e) logger.warning("error while enabling memory efficient attention: %s", e)
IMAGE_PIPELINES = [
OnnxStableDiffusionControlNetPipeline,
OnnxStableDiffusionImg2ImgPipeline,
OnnxStableDiffusionInpaintPipeline,
OnnxStableDiffusionInstructPix2PixPipeline,
OnnxStableDiffusionLongPromptWeightingPipeline,
OnnxStableDiffusionPanoramaPipeline,
OnnxStableDiffusionUpscalePipeline,
ORTStableDiffusionXLImg2ImgPipeline,
ORTStableDiffusionXLPanoramaPipeline,
]
def patch_pipeline( def patch_pipeline(
server: ServerContext, server: ServerContext,
pipe: StableDiffusionPipeline, pipe: StableDiffusionPipeline,
@ -654,9 +667,15 @@ def patch_pipeline(
logger.debug("patching prompt encoder") logger.debug("patching prompt encoder")
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
logger.debug("patching pipeline scheduler") # the pipeline requested in params may not be the one currently being used, especially during the later img2img
# stages of a highres pipeline, so we need to check the pipeline type
is_text_pipeline = type(pipe) not in IMAGE_PIPELINES
logger.debug(
"patching pipeline scheduler for %s pipeline",
"txt2img" if is_text_pipeline else "img2img",
)
original_scheduler = pipe.scheduler original_scheduler = pipe.scheduler
pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img()) pipe.scheduler = SchedulerPatch(server, original_scheduler, is_text_pipeline)
logger.debug("patching pipeline UNet") logger.debug("patching pipeline UNet")
original_unet = pipe.unet original_unet = pipe.unet

View File

@ -321,9 +321,6 @@ class ImageParams:
def is_pix2pix(self): def is_pix2pix(self):
return self.pipeline == "pix2pix" return self.pipeline == "pix2pix"
def is_txt2img(self):
return self.pipeline in ["txt2img", "txt2img-sdxl"]
def is_xl(self): def is_xl(self):
return self.pipeline.endswith("-sdxl") return self.pipeline.endswith("-sdxl")