1
0
Fork 0

add gradient to mirroring

This commit is contained in:
Sean Sube 2024-01-21 22:27:34 -06:00
parent 23018a79a3
commit bdf6f401a6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 22 additions and 14 deletions

View File

@ -23,31 +23,39 @@ class SchedulerPatch:
white_point = result.prev_sample.shape[2] // 8 white_point = result.prev_sample.shape[2] // 8
black_point = result.prev_sample.shape[2] // 4 black_point = result.prev_sample.shape[2] // 4
center_line = result.prev_sample.shape[2] // 2 center_line = result.prev_sample.shape[2] // 2
direction = "horizontal" # direction = "horizontal"
mirrored_latents = mirror_latents( gradient = linear_gradient(white_point, black_point, center_line)
result.prev_sample.numpy(), white_point, black_point, center_line, direction latents = result.prev_sample.numpy()
) latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient
# mirrored_latents = mirror_latents(
# result.prev_sample.numpy(), gradient, center_line, direction
# )
return SchedulerOutput( return SchedulerOutput(
prev_sample=torch.from_numpy(mirrored_latents), prev_sample=torch.from_numpy(latents),
) )
def linear_gradient(
white_point: int,
black_point: int,
center_line: int,
) -> np.ndarray:
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1)
gradient = np.pad(gradient, (0, center_line - black_point), mode="constant")
gradient = np.reshape([gradient, np.flip(gradient)], -1)
return np.expand_dims(gradient, (0, 1, 2))
def mirror_latents( def mirror_latents(
latents: np.ndarray, latents: np.ndarray,
white_point: int, gradient: np.ndarray,
black_point: int,
center_line: int, center_line: int,
direction: Literal["horizontal", "vertical"], direction: Literal["horizontal", "vertical"],
) -> np.ndarray: ) -> np.ndarray:
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
gradient = np.pad(
gradient, (white_point, center_line - black_point), mode="constant"
)
gradient = np.reshape([gradient, np.flip(gradient)], -1)
gradient = np.expand_dims(gradient, (0, 1, 2))
if direction == "horizontal": if direction == "horizontal":
pad_left = max(0, -center_line) pad_left = max(0, -center_line)
pad_right = max(0, 2 * center_line - latents.shape[3]) pad_right = max(0, 2 * center_line - latents.shape[3])