1
0
Fork 0

disable latent mirroring for img2img pipelines

This commit is contained in:
Sean Sube 2024-01-27 14:18:40 -06:00
parent 2773ab0965
commit ddb8f4fadf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 9 additions and 5 deletions

View File

@ -22,7 +22,6 @@ from .patches.scheduler import SchedulerPatch
from .patches.unet import UNetWrapper from .patches.unet import UNetWrapper
from .patches.vae import VAEWrapper from .patches.vae import VAEWrapper
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.highres import OnnxStableDiffusionHighresPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
@ -55,7 +54,7 @@ logger = getLogger(__name__)
available_pipelines = { available_pipelines = {
"controlnet": OnnxStableDiffusionControlNetPipeline, "controlnet": OnnxStableDiffusionControlNetPipeline,
"highres": OnnxStableDiffusionHighresPipeline, # "highres": OnnxStableDiffusionHighresPipeline,
"img2img": OnnxStableDiffusionImg2ImgPipeline, "img2img": OnnxStableDiffusionImg2ImgPipeline,
"img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline, "img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline,
"inpaint": OnnxStableDiffusionInpaintPipeline, "inpaint": OnnxStableDiffusionInpaintPipeline,
@ -657,7 +656,7 @@ def patch_pipeline(
logger.debug("patching pipeline scheduler") logger.debug("patching pipeline scheduler")
original_scheduler = pipe.scheduler original_scheduler = pipe.scheduler
pipe.scheduler = SchedulerPatch(server, original_scheduler) pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img())
logger.debug("patching pipeline UNet") logger.debug("patching pipeline UNet")
original_unet = pipe.unet original_unet = pipe.unet

View File

@ -13,11 +13,13 @@ logger = getLogger(__name__)
class SchedulerPatch: class SchedulerPatch:
server: ServerContext server: ServerContext
text_pipeline: bool
wrapped: Any wrapped: Any
def __init__(self, server: ServerContext, scheduler): def __init__(self, server: ServerContext, scheduler: Any, text_pipeline: bool):
self.server = server self.server = server
self.wrapped = scheduler self.wrapped = scheduler
self.text_pipeline = text_pipeline
def __getattr__(self, attr): def __getattr__(self, attr):
return getattr(self.wrapped, attr) return getattr(self.wrapped, attr)
@ -27,7 +29,7 @@ class SchedulerPatch:
) -> SchedulerOutput: ) -> SchedulerOutput:
result = self.wrapped.step(model_output, timestep, sample) result = self.wrapped.step(model_output, timestep, sample)
if self.server.has_feature("mirror-latents"): if self.text_pipeline and self.server.has_feature("mirror-latents"):
logger.info("using experimental latent mirroring") logger.info("using experimental latent mirroring")
white_point = 0 white_point = 0

View File

@ -321,6 +321,9 @@ class ImageParams:
def is_pix2pix(self): def is_pix2pix(self):
return self.pipeline == "pix2pix" return self.pipeline == "pix2pix"
def is_txt2img(self):
return self.pipeline in ["txt2img", "txt2img-sdxl"]
def is_xl(self): def is_xl(self):
return self.pipeline.endswith("-sdxl") return self.pipeline.endswith("-sdxl")