1
0
Fork 0

fix(api): correct gradient for vertical latent mirroring

This commit is contained in:
Sean Sube 2024-01-27 21:51:36 -06:00
parent c2ec5f9b2c
commit bdbe6549bc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 3 deletions

View File

@ -32,15 +32,20 @@ class SchedulerPatch:
if self.text_pipeline and self.server.has_feature("mirror-latents"):
logger.info("using experimental latent mirroring")
axis_of_symmetry = 3
if self.server.has_feature("mirror-latents-vertical"):
axis_of_symmetry = 2
expand_dims = (0, 1, 3)
else:
axis_of_symmetry = 3
expand_dims = (0, 1, 2)
white_point = 0
black_point = result.prev_sample.shape[axis_of_symmetry] // 8
center_line = result.prev_sample.shape[axis_of_symmetry] // 2
gradient = linear_gradient(white_point, black_point, center_line)
gradient = linear_gradient(
white_point, black_point, center_line, expand_dims
)
latents = result.prev_sample.numpy()
gradiated_latents = np.multiply(latents, gradient)
@ -70,12 +75,13 @@ def linear_gradient(
white_point: int,
black_point: int,
center_line: int,
expand_dims: tuple[int, ...] = (0, 1, 2),
) -> np.ndarray:
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1)
gradient = np.pad(gradient, (0, center_line - black_point), mode="constant")
gradient = np.reshape([gradient, np.flip(gradient)], -1)
return np.expand_dims(gradient, (0, 1, 2))
return np.expand_dims(gradient, expand_dims)
def mirror_latents(