add gradient to mirroring
This commit is contained in:
parent
23018a79a3
commit
bdf6f401a6
|
@ -23,31 +23,39 @@ class SchedulerPatch:
|
|||
white_point = result.prev_sample.shape[2] // 8
|
||||
black_point = result.prev_sample.shape[2] // 4
|
||||
center_line = result.prev_sample.shape[2] // 2
|
||||
direction = "horizontal"
|
||||
# direction = "horizontal"
|
||||
|
||||
mirrored_latents = mirror_latents(
|
||||
result.prev_sample.numpy(), white_point, black_point, center_line, direction
|
||||
)
|
||||
gradient = linear_gradient(white_point, black_point, center_line)
|
||||
latents = result.prev_sample.numpy()
|
||||
latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient
|
||||
|
||||
# mirrored_latents = mirror_latents(
|
||||
# result.prev_sample.numpy(), gradient, center_line, direction
|
||||
# )
|
||||
|
||||
return SchedulerOutput(
|
||||
prev_sample=torch.from_numpy(mirrored_latents),
|
||||
prev_sample=torch.from_numpy(latents),
|
||||
)
|
||||
|
||||
|
||||
def linear_gradient(
|
||||
white_point: int,
|
||||
black_point: int,
|
||||
center_line: int,
|
||||
) -> np.ndarray:
|
||||
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
||||
gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1)
|
||||
gradient = np.pad(gradient, (0, center_line - black_point), mode="constant")
|
||||
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
||||
return np.expand_dims(gradient, (0, 1, 2))
|
||||
|
||||
|
||||
def mirror_latents(
|
||||
latents: np.ndarray,
|
||||
white_point: int,
|
||||
black_point: int,
|
||||
gradient: np.ndarray,
|
||||
center_line: int,
|
||||
direction: Literal["horizontal", "vertical"],
|
||||
) -> np.ndarray:
|
||||
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
||||
gradient = np.pad(
|
||||
gradient, (white_point, center_line - black_point), mode="constant"
|
||||
)
|
||||
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
||||
gradient = np.expand_dims(gradient, (0, 1, 2))
|
||||
|
||||
if direction == "horizontal":
|
||||
pad_left = max(0, -center_line)
|
||||
pad_right = max(0, 2 * center_line - latents.shape[3])
|
||||
|
|
Loading…
Reference in New Issue