include flipped mask
This commit is contained in:
parent
e41fa04fe9
commit
23018a79a3
|
@ -67,6 +67,9 @@ def mirror_latents(
|
||||||
padded_mask = np.pad(
|
padded_mask = np.pad(
|
||||||
mask, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant"
|
mask, ((0, 0), (0, 0), (0, 0), (pad_left, pad_right)), mode="constant"
|
||||||
)
|
)
|
||||||
|
flipped_mask = np.flip(padded_mask, axis=3)
|
||||||
|
|
||||||
|
padded_mask += flipped_mask
|
||||||
padded_mask += np.multiply(np.ones_like(padded_array), gradient)
|
padded_mask += np.multiply(np.ones_like(padded_array), gradient)
|
||||||
|
|
||||||
# combine the two copies
|
# combine the two copies
|
||||||
|
@ -96,6 +99,9 @@ def mirror_latents(
|
||||||
padded_mask = np.pad(
|
padded_mask = np.pad(
|
||||||
mask, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant"
|
mask, ((0, 0), (0, 0), (pad_top, pad_bottom), (0, 0)), mode="constant"
|
||||||
)
|
)
|
||||||
|
flipped_mask = np.flip(padded_mask, axis=2)
|
||||||
|
|
||||||
|
padded_mask += flipped_mask
|
||||||
padded_mask += np.multiply(
|
padded_mask += np.multiply(
|
||||||
np.ones_like(padded_array).transpose(0, 1, 3, 2), gradient
|
np.ones_like(padded_array).transpose(0, 1, 3, 2), gradient
|
||||||
).transpose(0, 1, 3, 2)
|
).transpose(0, 1, 3, 2)
|
||||||
|
|
Loading…
Reference in New Issue