use pipeline type to decide scheduler patch options
This commit is contained in:
parent
ddb8f4fadf
commit
d8a0f5f15f
|
@ -642,6 +642,19 @@ def optimize_pipeline(
|
||||||
logger.warning("error while enabling memory efficient attention: %s", e)
|
logger.warning("error while enabling memory efficient attention: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
IMAGE_PIPELINES = [
|
||||||
|
OnnxStableDiffusionControlNetPipeline,
|
||||||
|
OnnxStableDiffusionImg2ImgPipeline,
|
||||||
|
OnnxStableDiffusionInpaintPipeline,
|
||||||
|
OnnxStableDiffusionInstructPix2PixPipeline,
|
||||||
|
OnnxStableDiffusionLongPromptWeightingPipeline,
|
||||||
|
OnnxStableDiffusionPanoramaPipeline,
|
||||||
|
OnnxStableDiffusionUpscalePipeline,
|
||||||
|
ORTStableDiffusionXLImg2ImgPipeline,
|
||||||
|
ORTStableDiffusionXLPanoramaPipeline,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def patch_pipeline(
|
def patch_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
|
@ -654,9 +667,15 @@ def patch_pipeline(
|
||||||
logger.debug("patching prompt encoder")
|
logger.debug("patching prompt encoder")
|
||||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||||
|
|
||||||
logger.debug("patching pipeline scheduler")
|
# the pipeline requested in params may not be the one currently being used, especially during the later img2img
|
||||||
|
# stages of a highres pipeline, so we need to check the pipeline type
|
||||||
|
is_text_pipeline = type(pipe) not in IMAGE_PIPELINES
|
||||||
|
logger.debug(
|
||||||
|
"patching pipeline scheduler for %s pipeline",
|
||||||
|
"txt2img" if is_text_pipeline else "img2img",
|
||||||
|
)
|
||||||
original_scheduler = pipe.scheduler
|
original_scheduler = pipe.scheduler
|
||||||
pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img())
|
pipe.scheduler = SchedulerPatch(server, original_scheduler, is_text_pipeline)
|
||||||
|
|
||||||
logger.debug("patching pipeline UNet")
|
logger.debug("patching pipeline UNet")
|
||||||
original_unet = pipe.unet
|
original_unet = pipe.unet
|
||||||
|
|
|
@ -321,9 +321,6 @@ class ImageParams:
|
||||||
def is_pix2pix(self):
|
def is_pix2pix(self):
|
||||||
return self.pipeline == "pix2pix"
|
return self.pipeline == "pix2pix"
|
||||||
|
|
||||||
def is_txt2img(self):
|
|
||||||
return self.pipeline in ["txt2img", "txt2img-sdxl"]
|
|
||||||
|
|
||||||
def is_xl(self):
|
def is_xl(self):
|
||||||
return self.pipeline.endswith("-sdxl")
|
return self.pipeline.endswith("-sdxl")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue