1
0
Fork 0

include flipped mask

This commit is contained in:
Sean Sube 2024-01-21 22:05:58 -06:00
parent e41fa04fe9
commit 23018a79a3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 6 additions and 0 deletions

View File

@ -67,6 +67,9 @@ def mirror_latents(
padded_mask = np.pad(
mask, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant"
)
flipped_mask = np.flip(padded_mask, axis=3)
padded_mask += flipped_mask
padded_mask += np.multiply(np.ones_like(padded_array), gradient)
# combine the two copies
@ -96,6 +99,9 @@ def mirror_latents(
padded_mask = np.pad(
mask, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant"
)
flipped_mask = np.flip(padded_mask, axis=2)
padded_mask += flipped_mask
padded_mask += np.multiply(
np.ones_like(padded_array).transpose(0, 1, 3, 2), gradient
).transpose(0, 1, 3, 2)