diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index 83b6ffb9..49b58793 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -32,15 +32,20 @@ class SchedulerPatch: if self.text_pipeline and self.server.has_feature("mirror-latents"): logger.info("using experimental latent mirroring") - axis_of_symmetry = 3 if self.server.has_feature("mirror-latents-vertical"): axis_of_symmetry = 2 + expand_dims = (0, 1, 3) + else: + axis_of_symmetry = 3 + expand_dims = (0, 1, 2) white_point = 0 black_point = result.prev_sample.shape[axis_of_symmetry] // 8 center_line = result.prev_sample.shape[axis_of_symmetry] // 2 - gradient = linear_gradient(white_point, black_point, center_line) + gradient = linear_gradient( + white_point, black_point, center_line, expand_dims + ) latents = result.prev_sample.numpy() gradiated_latents = np.multiply(latents, gradient) @@ -70,12 +75,13 @@ def linear_gradient( white_point: int, black_point: int, center_line: int, + expand_dims: tuple[int, ...] = (0, 1, 2), ) -> np.ndarray: gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32) gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1) gradient = np.pad(gradient, (0, center_line - black_point), mode="constant") gradient = np.reshape([gradient, np.flip(gradient)], -1) - return np.expand_dims(gradient, (0, 1, 2)) + return np.expand_dims(gradient, expand_dims) def mirror_latents(