add gradient to mirroring
This commit is contained in:
parent
23018a79a3
commit
bdf6f401a6
|
@ -23,31 +23,39 @@ class SchedulerPatch:
|
||||||
white_point = result.prev_sample.shape[2] // 8
|
white_point = result.prev_sample.shape[2] // 8
|
||||||
black_point = result.prev_sample.shape[2] // 4
|
black_point = result.prev_sample.shape[2] // 4
|
||||||
center_line = result.prev_sample.shape[2] // 2
|
center_line = result.prev_sample.shape[2] // 2
|
||||||
direction = "horizontal"
|
# direction = "horizontal"
|
||||||
|
|
||||||
mirrored_latents = mirror_latents(
|
gradient = linear_gradient(white_point, black_point, center_line)
|
||||||
result.prev_sample.numpy(), white_point, black_point, center_line, direction
|
latents = result.prev_sample.numpy()
|
||||||
)
|
latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient
|
||||||
|
|
||||||
|
# mirrored_latents = mirror_latents(
|
||||||
|
# result.prev_sample.numpy(), gradient, center_line, direction
|
||||||
|
# )
|
||||||
|
|
||||||
return SchedulerOutput(
|
return SchedulerOutput(
|
||||||
prev_sample=torch.from_numpy(mirrored_latents),
|
prev_sample=torch.from_numpy(latents),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_gradient(
|
||||||
|
white_point: int,
|
||||||
|
black_point: int,
|
||||||
|
center_line: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
||||||
|
gradient = np.pad(gradient, (white_point, 0), mode="constant", constant_values=1)
|
||||||
|
gradient = np.pad(gradient, (0, center_line - black_point), mode="constant")
|
||||||
|
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
||||||
|
return np.expand_dims(gradient, (0, 1, 2))
|
||||||
|
|
||||||
|
|
||||||
def mirror_latents(
|
def mirror_latents(
|
||||||
latents: np.ndarray,
|
latents: np.ndarray,
|
||||||
white_point: int,
|
gradient: np.ndarray,
|
||||||
black_point: int,
|
|
||||||
center_line: int,
|
center_line: int,
|
||||||
direction: Literal["horizontal", "vertical"],
|
direction: Literal["horizontal", "vertical"],
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
gradient = np.linspace(1, 0, black_point - white_point).astype(np.float32)
|
|
||||||
gradient = np.pad(
|
|
||||||
gradient, (white_point, center_line - black_point), mode="constant"
|
|
||||||
)
|
|
||||||
gradient = np.reshape([gradient, np.flip(gradient)], -1)
|
|
||||||
gradient = np.expand_dims(gradient, (0, 1, 2))
|
|
||||||
|
|
||||||
if direction == "horizontal":
|
if direction == "horizontal":
|
||||||
pad_left = max(0, -center_line)
|
pad_left = max(0, -center_line)
|
||||||
pad_right = max(0, 2 * center_line - latents.shape[3])
|
pad_right = max(0, 2 * center_line - latents.shape[3])
|
||||||
|
|
Loading…
Reference in New Issue