include flipped mask
This commit is contained in:
parent
e41fa04fe9
commit
23018a79a3
|
@ -67,6 +67,9 @@ def mirror_latents(
|
|||
padded_mask = np.pad(
|
||||
mask, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant"
|
||||
)
|
||||
flipped_mask = np.flip(padded_mask, axis=3)
|
||||
|
||||
padded_mask += flipped_mask
|
||||
padded_mask += np.multiply(np.ones_like(padded_array), gradient)
|
||||
|
||||
# combine the two copies
|
||||
|
@ -96,6 +99,9 @@ def mirror_latents(
|
|||
padded_mask = np.pad(
|
||||
mask, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant"
|
||||
)
|
||||
flipped_mask = np.flip(padded_mask, axis=2)
|
||||
|
||||
padded_mask += flipped_mask
|
||||
padded_mask += np.multiply(
|
||||
np.ones_like(padded_array).transpose(0, 1, 3, 2), gradient
|
||||
).transpose(0, 1, 3, 2)
|
||||
|
|
Loading…
Reference in New Issue