2024-01-22 03:36:39 +00:00
|
|
|
from typing import Any, Literal
|
|
|
|
|
|
|
|
import numpy as np
|
2024-01-22 03:53:45 +00:00
|
|
|
import torch
|
2024-01-22 03:49:24 +00:00
|
|
|
from diffusers.schedulers.scheduling_utils import SchedulerOutput
|
2024-01-22 03:36:39 +00:00
|
|
|
from torch import FloatTensor, Tensor
|
|
|
|
|
|
|
|
|
|
|
|
class SchedulerPatch:
    """Scheduler wrapper that blends each step's latents with their mirror image.

    Every attribute other than ``wrapped`` and ``step`` is looked up on the
    wrapped scheduler, so the patch can be used wherever the scheduler is.
    """

    # The scheduler instance being wrapped.
    wrapped: Any

    def __init__(self, scheduler):
        """Store *scheduler* as the wrapped instance."""
        self.wrapped = scheduler

    def __getattr__(self, attr):
        """Delegate lookups that fail on the patch to the wrapped scheduler.

        Raises:
            AttributeError: if ``wrapped`` itself has not been set yet, or if
                the wrapped scheduler does not have *attr*.
        """
        # __getattr__ only runs when normal lookup fails. If `wrapped` is
        # missing (e.g. while copy/pickle rebuild the instance), reading
        # self.wrapped below would re-enter this method forever.
        if attr == "wrapped":
            raise AttributeError(attr)

        return getattr(self.wrapped, attr)

    def step(
        self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
    ) -> "SchedulerOutput":
        """Run the wrapped scheduler step, then mix in the horizontal mirror.

        The mean of the latents and their flip along axis 3 is weighted by a
        gradient from ``linear_gradient`` and added to the latents.

        Args:
            model_output: forwarded unchanged to the wrapped ``step``.
            timestep: forwarded unchanged to the wrapped ``step``.
            sample: forwarded unchanged to the wrapped ``step``.

        Returns:
            A new ``SchedulerOutput`` holding only the blended ``prev_sample``,
            on the same device and with the same dtype as the wrapped result.
        """
        result = self.wrapped.step(model_output, timestep, sample)
        prev_sample = result.prev_sample

        # The gradient is broadcast along axis 3, so it has to be sized from
        # that axis; sizing it from axis 2 only works for square latents.
        width = prev_sample.shape[3]
        white_point = width // 8
        black_point = width // 4
        center_line = width // 2

        gradient = linear_gradient(white_point, black_point, center_line)
        if width % 2:
            # linear_gradient yields 2 * center_line weights, one short for an
            # odd width; give the middle column the zero weight its
            # neighbours already have.
            gradient = np.insert(gradient, center_line, 0.0, axis=3)

        # Tensor.numpy() only accepts CPU tensors that do not require grad.
        # For a CPU tensor the array shares memory with prev_sample, so the
        # in-place add below also updates the wrapped result.
        latents = prev_sample.detach().cpu().numpy()
        latents += np.mean([latents, np.flip(latents, axis=3)], axis=0) * gradient

        # NOTE: mirror_latents implements a padded, direction-aware blend but
        # is not used here.
        return SchedulerOutput(
            prev_sample=torch.from_numpy(latents).to(
                device=prev_sample.device, dtype=prev_sample.dtype
            ),
        )
|
|
|
|
|
|
|
|
|
2024-01-22 04:27:34 +00:00
|
|
|
def linear_gradient(
    white_point: int,
    black_point: int,
    center_line: int,
) -> np.ndarray:
    """Build a symmetric 1 -> 0 -> 1 weight ramp shaped ``(1, 1, 1, 2 * center_line)``.

    The first half holds ``white_point`` ones, then a linear fall from 1 to 0
    over ``black_point - white_point`` samples, then zeros up to
    ``center_line``. The second half is the first half reversed.

    Args:
        white_point: number of leading entries fixed at 1.
        black_point: index at which the ramp has reached 0.
        center_line: length of one half of the gradient.

    Returns:
        A float32 array with three leading singleton axes, so it broadcasts
        against the last axis of a 4-D array.
    """
    ramp = np.linspace(1, 0, black_point - white_point).astype(np.float32)

    # One pad call supplies both plateaus: ones before the ramp, zeros after.
    half = np.pad(
        ramp,
        (white_point, center_line - black_point),
        mode="constant",
        constant_values=(1, 0),
    )

    mirrored = np.concatenate([half, half[::-1]])
    return mirrored[np.newaxis, np.newaxis, np.newaxis, :]
|
|
|
|
|
2024-01-22 03:36:39 +00:00
|
|
|
|
2024-01-22 04:27:34 +00:00
|
|
|
def _mirror_last_axis(
    latents: np.ndarray,
    gradient: np.ndarray,
    center_line: int,
) -> np.ndarray:
    """Blend 4-D *latents* with their mirror image along axis 3."""
    width = latents.shape[3]
    pad_before = max(0, -center_line)
    pad_after = max(0, 2 * center_line - width)
    pad_width = ((0, 0), (0, 0), (0, 0), (pad_before, pad_after))

    # create the symmetrical copies
    padded = np.pad(latents, pad_width, mode="constant")
    flipped = np.flip(padded, axis=3)

    # apply the gradient to both copies
    padded_weighted = np.multiply(padded, gradient)
    flipped_weighted = np.multiply(flipped, gradient)

    # produce the normalisation mask: 1 wherever the original or its mirror
    # has real (non-padding) data, plus the gradient weight
    coverage = np.pad(
        np.ones_like(latents).astype(np.float32), pad_width, mode="constant"
    )
    # Out-of-place on purpose: np.flip returns a view of `coverage`.
    coverage = coverage + np.flip(coverage, axis=3)
    coverage += np.multiply(np.ones_like(padded), gradient)

    # combine the two copies
    # NOTE(review): the denominator (coverage) is not the sum of the blend
    # weights (1 + 2 * gradient) -- confirm this normalisation is intended.
    blended = padded + padded_weighted + flipped_weighted
    with np.errstate(divide="ignore", invalid="ignore"):
        # Cells with zero coverage are left undivided by np.where; the
        # errstate only silences the warning from evaluating the quotient.
        blended = np.where(coverage > 0, blended / coverage, blended)

    # drop the padding again
    return blended[:, :, :, pad_before : pad_before + width]


def mirror_latents(
    latents: np.ndarray,
    gradient: np.ndarray,
    center_line: int,
    direction: Literal["horizontal", "vertical"],
) -> np.ndarray:
    """Blend *latents* with a gradient-weighted mirror image of themselves.

    The latents are zero-padded so that ``2 * center_line`` entries exist
    along the mirrored axis, flipped along that axis, and the original, the
    weighted original and the weighted flip are summed and normalised. The
    result is cropped back to the input size.

    Args:
        latents: 4-D array; axis 3 is mirrored for ``"horizontal"`` and
            axis 2 for ``"vertical"``.
        gradient: weights broadcast along the mirrored axis (for example the
            output of ``linear_gradient``); its length must match the padded
            size of that axis.
        center_line: position of the mirror line along the mirrored axis.
        direction: which axis to mirror.

    Returns:
        The blended array, same shape as *latents*.

    Raises:
        ValueError: if *direction* is neither ``"horizontal"`` nor
            ``"vertical"``.
    """
    if direction == "horizontal":
        return _mirror_last_axis(latents, gradient, center_line)

    if direction == "vertical":
        # Same blend with axes 2 and 3 exchanged, so the gradient still
        # broadcasts along the last axis. The original vertical branch
        # returned the flipped input instead of the blended result.
        swapped = latents.swapaxes(2, 3)
        return _mirror_last_axis(swapped, gradient, center_line).swapaxes(2, 3)

    raise ValueError("Invalid direction. Must be 'horizontal' or 'vertical'.")
|