From bdf6f401a66144718e6ef5d69659c175373221c7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 21 Jan 2024 22:27:34 -0600 Subject: [PATCH] add gradient to mirroring --- api/onnx_web/diffusers/patches/scheduler.py | 36 +++++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index ac3658d0..3f99ae6c 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -23,31 +23,39 @@ class SchedulerPatch: white_point = result.prev_sample.shape[2] // 8 black_point = result.prev_sample.shape[2] // 4 center_line = result.prev_sample.shape[2] // 2 - direction = "horizontal" + # direction = "horizontal" - mirrored_latents = mirror_latents( - result.prev_sample.numpy(), white_point, black_point, center_line, direction - ) + gradient = linear_gradient(white_point, black_point, center_line) + latents = result.prev_sample.numpy() + latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient + + # mirrored_latents = mirror_latents( + # result.prev_sample.numpy(), gradient, center_line, direction + # ) return SchedulerOutput( - prev_sample=torch.from_numpy(mirrored_latents), + prev_sample=torch.from_numpy(latents), ) +def linear_gradient( + white_point: int, + black_point: int, + center_line: int, +) -> np.ndarray: + gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32) + gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1) + gradient = np.pad(gradient, (0, center_line - black_point), mode="constant") + gradient = np.reshape([gradient, np.flip(gradient)], -1) + return np.expand_dims(gradient, (0, 1, 2)) + + def mirror_latents( latents: np.ndarray, - white_point: int, - black_point: int, + gradient: np.ndarray, center_line: int, direction: Literal["horizontal", "vertical"], ) -> np.ndarray: - gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32) - gradient = np.pad( - gradient, (white_point, center_line - black_point), mode="constant" - ) - gradient = np.reshape([gradient, np.flip(gradient)], -1) - gradient = np.expand_dims(gradient, (0, 1, 2)) - if direction == "horizontal": pad_left = max(0, -center_line) pad_right = max(0, 2 * center_line - latents.shape[3])