diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 5c1d1dc2..f123ddc1 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -652,11 +652,12 @@ def patch_pipeline( logger.debug("patching SD pipeline") if not params.is_lpw() and not params.is_xl(): + logger.debug("patching prompt encoder") pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline) logger.debug("patching pipeline scheduler") original_scheduler = pipe.scheduler - pipe.scheduler = SchedulerPatch(original_scheduler) + pipe.scheduler = SchedulerPatch(server, original_scheduler) logger.debug("patching pipeline UNet") original_unet = pipe.unet diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index 3f99ae6c..6492dcfd 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -1,3 +1,4 @@ +from logging import getLogger from typing import Any, Literal import numpy as np @@ -5,11 +6,17 @@ import torch from diffusers.schedulers.scheduling_utils import SchedulerOutput from torch import FloatTensor, Tensor +from ...server.context import ServerContext + +logger = getLogger(__name__) + class SchedulerPatch: + server: ServerContext wrapped: Any - def __init__(self, scheduler): + def __init__(self, server: ServerContext, scheduler): + self.server = server self.wrapped = scheduler def __getattr__(self, attr): @@ -20,22 +27,33 @@ class SchedulerPatch: ) -> SchedulerOutput: result = self.wrapped.step(model_output, timestep, sample) - white_point = result.prev_sample.shape[2] // 8 - black_point = result.prev_sample.shape[2] // 4 - center_line = result.prev_sample.shape[2] // 2 - # direction = "horizontal" + if self.server.has_feature("mirror-latents"): + logger.info("using experimental latent mirroring") - gradient = linear_gradient(white_point, black_point, center_line) - latents = result.prev_sample.numpy() - latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient + white_point = 0 + black_point = result.prev_sample.shape[2] // 8 + center_line = result.prev_sample.shape[2] // 2 - # mirrored_latents = mirror_latents( - # result.prev_sample.numpy(), gradient, center_line, direction - # ) + gradient = linear_gradient(white_point, black_point, center_line) + latents = result.prev_sample.numpy() - return SchedulerOutput( - prev_sample=torch.from_numpy(latents), - ) + gradiated_latents = np.multiply(latents, gradient) + inverse_gradiated_latents = np.multiply(np.flip(latents, axis=3), gradient) + latents += gradiated_latents + inverse_gradiated_latents + + mask = np.ones_like(latents).astype(np.float32) + gradiated_mask = np.multiply(mask, gradient) + # flipping the mask would do nothing, we need to flip the gradient for this one + inverse_gradiated_mask = np.multiply(mask, np.flip(gradient, axis=3)) + mask += gradiated_mask + inverse_gradiated_mask + + latents = np.where(mask > 0, latents / mask, latents) + + return SchedulerOutput( + prev_sample=torch.from_numpy(latents), + ) + else: + return result def linear_gradient(