fix(api): correct gradient for vertical latent mirroring
This commit is contained in:
parent
c2ec5f9b2c
commit
bdbe6549bc
|
@ -32,15 +32,20 @@ class SchedulerPatch:
|
||||||
if self.text_pipeline and self.server.has_feature("mirror-latents"):
|
if self.text_pipeline and self.server.has_feature("mirror-latents"):
|
||||||
logger.info("using experimental latent mirroring")
|
logger.info("using experimental latent mirroring")
|
||||||
|
|
||||||
axis_of_symmetry = 3
|
|
||||||
if self.server.has_feature("mirror-latents-vertical"):
|
if self.server.has_feature("mirror-latents-vertical"):
|
||||||
axis_of_symmetry = 2
|
axis_of_symmetry = 2
|
||||||
|
expand_dims = (0, 1, 3)
|
||||||
|
else:
|
||||||
|
axis_of_symmetry = 3
|
||||||
|
expand_dims = (0, 1, 2)
|
||||||
|
|
||||||
white_point = 0
|
white_point = 0
|
||||||
black_point = result.prev_sample.shape[axis_of_symmetry] // 8
|
black_point = result.prev_sample.shape[axis_of_symmetry] // 8
|
||||||
center_line = result.prev_sample.shape[axis_of_symmetry] // 2
|
center_line = result.prev_sample.shape[axis_of_symmetry] // 2
|
||||||
|
|
||||||
gradient = linear_gradient(white_point, black_point, center_line)
|
gradient = linear_gradient(
|
||||||
|
white_point, black_point, center_line, expand_dims
|
||||||
|
)
|
||||||
latents = result.prev_sample.numpy()
|
latents = result.prev_sample.numpy()
|
||||||
|
|
||||||
gradiated_latents = np.multiply(latents, gradient)
|
gradiated_latents = np.multiply(latents, gradient)
|
||||||
|
@ -70,12 +75,13 @@ def linear_gradient(
|
||||||
white_point: int,
|
white_point: int,
|
||||||
black_point: int,
|
black_point: int,
|
||||||
center_line: int,
|
center_line: int,
|
||||||
|
expand_dims: tuple[int, ...] = (0, 1, 2),
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
||||||
gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1)
|
gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1)
|
||||||
gradient = np.pad(gradient, (0, center_line - black_point), mode="constant")
|
gradient = np.pad(gradient, (0, center_line - black_point), mode="constant")
|
||||||
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
||||||
return np.expand_dims(gradient, (0, 1, 2))
|
return np.expand_dims(gradient, expand_dims)
|
||||||
|
|
||||||
|
|
||||||
def mirror_latents(
|
def mirror_latents(
|
||||||
|
|
Loading…
Reference in New Issue