disable latent mirroring for img2img pipelines
This commit is contained in:
parent
2773ab0965
commit
ddb8f4fadf
|
@ -22,7 +22,6 @@ from .patches.scheduler import SchedulerPatch
|
||||||
from .patches.unet import UNetWrapper
|
from .patches.unet import UNetWrapper
|
||||||
from .patches.vae import VAEWrapper
|
from .patches.vae import VAEWrapper
|
||||||
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
|
||||||
from .pipelines.highres import OnnxStableDiffusionHighresPipeline
|
|
||||||
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
|
from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
|
||||||
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
|
from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
|
||||||
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
|
from .pipelines.panorama_xl import ORTStableDiffusionXLPanoramaPipeline
|
||||||
|
@ -55,7 +54,7 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
available_pipelines = {
|
available_pipelines = {
|
||||||
"controlnet": OnnxStableDiffusionControlNetPipeline,
|
"controlnet": OnnxStableDiffusionControlNetPipeline,
|
||||||
"highres": OnnxStableDiffusionHighresPipeline,
|
# "highres": OnnxStableDiffusionHighresPipeline,
|
||||||
"img2img": OnnxStableDiffusionImg2ImgPipeline,
|
"img2img": OnnxStableDiffusionImg2ImgPipeline,
|
||||||
"img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline,
|
"img2img-sdxl": ORTStableDiffusionXLImg2ImgPipeline,
|
||||||
"inpaint": OnnxStableDiffusionInpaintPipeline,
|
"inpaint": OnnxStableDiffusionInpaintPipeline,
|
||||||
|
@ -657,7 +656,7 @@ def patch_pipeline(
|
||||||
|
|
||||||
logger.debug("patching pipeline scheduler")
|
logger.debug("patching pipeline scheduler")
|
||||||
original_scheduler = pipe.scheduler
|
original_scheduler = pipe.scheduler
|
||||||
pipe.scheduler = SchedulerPatch(server, original_scheduler)
|
pipe.scheduler = SchedulerPatch(server, original_scheduler, params.is_txt2img())
|
||||||
|
|
||||||
logger.debug("patching pipeline UNet")
|
logger.debug("patching pipeline UNet")
|
||||||
original_unet = pipe.unet
|
original_unet = pipe.unet
|
||||||
|
|
|
@ -13,11 +13,13 @@ logger = getLogger(__name__)
|
||||||
|
|
||||||
class SchedulerPatch:
|
class SchedulerPatch:
|
||||||
server: ServerContext
|
server: ServerContext
|
||||||
|
text_pipeline: bool
|
||||||
wrapped: Any
|
wrapped: Any
|
||||||
|
|
||||||
def __init__(self, server: ServerContext, scheduler):
|
def __init__(self, server: ServerContext, scheduler: Any, text_pipeline: bool):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.wrapped = scheduler
|
self.wrapped = scheduler
|
||||||
|
self.text_pipeline = text_pipeline
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
return getattr(self.wrapped, attr)
|
return getattr(self.wrapped, attr)
|
||||||
|
@ -27,7 +29,7 @@ class SchedulerPatch:
|
||||||
) -> SchedulerOutput:
|
) -> SchedulerOutput:
|
||||||
result = self.wrapped.step(model_output, timestep, sample)
|
result = self.wrapped.step(model_output, timestep, sample)
|
||||||
|
|
||||||
if self.server.has_feature("mirror-latents"):
|
if self.text_pipeline and self.server.has_feature("mirror-latents"):
|
||||||
logger.info("using experimental latent mirroring")
|
logger.info("using experimental latent mirroring")
|
||||||
|
|
||||||
white_point = 0
|
white_point = 0
|
||||||
|
|
|
@ -321,6 +321,9 @@ class ImageParams:
|
||||||
def is_pix2pix(self):
|
def is_pix2pix(self):
|
||||||
return self.pipeline == "pix2pix"
|
return self.pipeline == "pix2pix"
|
||||||
|
|
||||||
|
def is_txt2img(self):
|
||||||
|
return self.pipeline in ["txt2img", "txt2img-sdxl"]
|
||||||
|
|
||||||
def is_xl(self):
|
def is_xl(self):
|
||||||
return self.pipeline.endswith("-sdxl")
|
return self.pipeline.endswith("-sdxl")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue