fix(api): correct gradient for vertical latent mirroring
This commit is contained in:
parent
c2ec5f9b2c
commit
bdbe6549bc
|
@ -32,15 +32,20 @@ class SchedulerPatch:
|
|||
if self.text_pipeline and self.server.has_feature("mirror-latents"):
|
||||
logger.info("using experimental latent mirroring")
|
||||
|
||||
axis_of_symmetry = 3
|
||||
if self.server.has_feature("mirror-latents-vertical"):
|
||||
axis_of_symmetry = 2
|
||||
expand_dims = (0, 1, 3)
|
||||
else:
|
||||
axis_of_symmetry = 3
|
||||
expand_dims = (0, 1, 2)
|
||||
|
||||
white_point = 0
|
||||
black_point = result.prev_sample.shape[axis_of_symmetry] // 8
|
||||
center_line = result.prev_sample.shape[axis_of_symmetry] // 2
|
||||
|
||||
gradient = linear_gradient(white_point, black_point, center_line)
|
||||
gradient = linear_gradient(
|
||||
white_point, black_point, center_line, expand_dims
|
||||
)
|
||||
latents = result.prev_sample.numpy()
|
||||
|
||||
gradiated_latents = np.multiply(latents, gradient)
|
||||
|
@ -70,12 +75,13 @@ def linear_gradient(
|
|||
white_point: int,
|
||||
black_point: int,
|
||||
center_line: int,
|
||||
expand_dims: tuple[int, ...] = (0, 1, 2),
|
||||
) -> np.ndarray:
|
||||
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
||||
gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1)
|
||||
gradient = np.pad(gradient, (0, center_line - black_point), mode="constant")
|
||||
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
||||
return np.expand_dims(gradient, (0, 1, 2))
|
||||
return np.expand_dims(gradient, expand_dims)
|
||||
|
||||
|
||||
def mirror_latents(
|
||||
|
|
Loading…
Reference in New Issue