diff --git a/api/onnx_web/diffusers/patches/scheduler.py b/api/onnx_web/diffusers/patches/scheduler.py index 2f9a358f..ac3658d0 100644 --- a/api/onnx_web/diffusers/patches/scheduler.py +++ b/api/onnx_web/diffusers/patches/scheduler.py @@ -67,6 +67,9 @@ def mirror_latents( padded_mask = np.pad( mask, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant" ) + flipped_mask = np.flip(padded_mask, axis=3) + + padded_mask += flipped_mask padded_mask += np.multiply(np.ones_like(padded_array), gradient) # combine the two copies @@ -96,6 +99,9 @@ def mirror_latents( padded_mask = np.pad( mask, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant" ) + flipped_mask = np.flip(padded_mask, axis=2) + + padded_mask += flipped_mask padded_mask += np.multiply( np.ones_like(padded_array).transpose(0, 1, 3, 2), gradient ).transpose(0, 1, 3, 2)