disable latent mirroring for img2img pipelines
This commit is contained in:
parent
2773ab0965
commit
ddb8f4fadf
|
@ -22,7 +22,6 @@ from .patches.scheduler import SchedulerPatch
|
|||
from .patches.unet import UNetWrapper
|
||||
from .patches.vae import VAEWrapper
|
||||
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||
from .pipelines.highres import OnnxStableDiffusionHighresPipeline
|
||||
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
|
||||
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
|
||||
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
|
||||
|
@ -55,7 +54,7 @@ logger = getLogger(__name__)
|
|||
|
||||
available_pipelines = {
|
||||
"controlnet": OnnxStableDiffusionControlNetPipeline,
|
||||
"highres": OnnxStableDiffusionHighresPipeline,
|
||||
# "highres": OnnxStableDiffusionHighresPipeline,
|
||||
"img2img": OnnxStableDiffusionImg2ImgPipeline,
|
||||
"img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline,
|
||||
"inpaint": OnnxStableDiffusionInpaintPipeline,
|
||||
|
@ -657,7 +656,7 @@ def patch_pipeline(
|
|||
|
||||
logger.debug("patching pipeline scheduler")
|
||||
original_scheduler = pipe.scheduler
|
||||
pipe.scheduler = SchedulerPatch(server, original_scheduler)
|
||||
pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img())
|
||||
|
||||
logger.debug("patching pipeline UNet")
|
||||
original_unet = pipe.unet
|
||||
|
|
|
@ -13,11 +13,13 @@ logger = getLogger(__name__)
|
|||
|
||||
class SchedulerPatch:
|
||||
server: ServerContext
|
||||
text_pipeline: bool
|
||||
wrapped: Any
|
||||
|
||||
def __init__(self, server: ServerContext, scheduler):
|
||||
def __init__(self, server: ServerContext, scheduler: Any, text_pipeline: bool):
|
||||
self.server = server
|
||||
self.wrapped = scheduler
|
||||
self.text_pipeline = text_pipeline
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
@ -27,7 +29,7 @@ class SchedulerPatch:
|
|||
) -> SchedulerOutput:
|
||||
result = self.wrapped.step(model_output, timestep, sample)
|
||||
|
||||
if self.server.has_feature("mirror-latents"):
|
||||
if self.text_pipeline and self.server.has_feature("mirror-latents"):
|
||||
logger.info("using experimental latent mirroring")
|
||||
|
||||
white_point = 0
|
||||
|
|
|
@ -321,6 +321,9 @@ class ImageParams:
|
|||
def is_pix2pix(self):
|
||||
return self.pipeline == "pix2pix"
|
||||
|
||||
def is_txt2img(self):
|
||||
return self.pipeline in ["txt2img", "txt2img-sdxl"]
|
||||
|
||||
def is_xl(self):
|
||||
return self.pipeline.endswith("-sdxl")
|
||||
|
||||
|
|
Loading…
Reference in New Issue