1
0
Fork 0

disable latent mirroring for img2img pipelines

This commit is contained in:
Sean Sube 2024-01-27 14:18:40 -06:00
parent 2773ab0965
commit ddb8f4fadf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 9 additions and 5 deletions

View File

@ -22,7 +22,6 @@ from .patches.scheduler import SchedulerPatch
from .patches.unet import UNetWrapper
from .patches.vae import VAEWrapper
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
from .pipelines.highres import OnnxStableDiffusionHighresPipeline
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
@ -55,7 +54,7 @@ logger = getLogger(__name__)
available_pipelines = {
"controlnet": OnnxStableDiffusionControlNetPipeline,
"highres": OnnxStableDiffusionHighresPipeline,
# "highres": OnnxStableDiffusionHighresPipeline,
"img2img": OnnxStableDiffusionImg2ImgPipeline,
"img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline,
"inpaint": OnnxStableDiffusionInpaintPipeline,
@ -657,7 +656,7 @@ def patch_pipeline(
logger.debug("patching pipeline scheduler")
original_scheduler = pipe.scheduler
pipe.scheduler = SchedulerPatch(server, original_scheduler)
pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img())
logger.debug("patching pipeline UNet")
original_unet = pipe.unet

View File

@ -13,11 +13,13 @@ logger = getLogger(__name__)
class SchedulerPatch:
server: ServerContext
text_pipeline: bool
wrapped: Any
def __init__(self, server: ServerContext, scheduler):
def __init__(self, server: ServerContext, scheduler: Any, text_pipeline: bool):
self.server = server
self.wrapped = scheduler
self.text_pipeline = text_pipeline
def __getattr__(self, attr):
return getattr(self.wrapped, attr)
@ -27,7 +29,7 @@ class SchedulerPatch:
) -> SchedulerOutput:
result = self.wrapped.step(model_output, timestep, sample)
if self.server.has_feature("mirror-latents"):
if self.text_pipeline and self.server.has_feature("mirror-latents"):
logger.info("using experimental latent mirroring")
white_point = 0

View File

@ -321,6 +321,9 @@ class ImageParams:
def is_pix2pix(self):
return self.pipeline == "pix2pix"
def is_txt2img(self):
return self.pipeline in ["txt2img", "txt2img-sdxl"]
def is_xl(self):
return self.pipeline.endswith("-sdxl")