1
0
Fork 0
onnx-web/api/onnx_web/diffusers/patches/scheduler.py

159 lines
5.6 KiB
Python
Raw Normal View History

from logging import getLogger
2024-01-22 03:36:39 +00:00
from typing import Any, Literal
import numpy as np
2024-01-22 03:53:45 +00:00
import torch
2024-01-22 03:49:24 +00:00
from diffusers.schedulers.scheduling_utils import SchedulerOutput
2024-01-22 03:36:39 +00:00
from torch import FloatTensor, Tensor
from ...server.context import ServerContext
logger = getLogger(__name__)
2024-01-22 03:36:39 +00:00
class SchedulerPatch:
    """Proxy around a scheduler that can mirror the latents produced by ``step``.

    Every attribute that is not defined on the patch itself is looked up on
    the wrapped scheduler, so the patch can be used wherever the scheduler
    was. Mirroring is only applied when ``text_pipeline`` is true and the
    server reports the ``mirror-latents`` feature; otherwise ``step`` returns
    the wrapped scheduler's result untouched.
    """

    # server context, queried for the "mirror-latents*" feature flags
    server: ServerContext
    # mirroring is only considered when this flag is true
    text_pipeline: bool
    # the scheduler that does the real work
    wrapped: Any

    def __init__(self, server: ServerContext, scheduler: Any, text_pipeline: bool):
        self.server = server
        self.wrapped = scheduler
        self.text_pipeline = text_pipeline

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. If `wrapped` itself
        # is missing (e.g. copy/pickle rebuilding the object without calling
        # __init__), delegating through `self.wrapped` would re-enter this
        # method forever, so fail with a regular AttributeError instead.
        if attr == "wrapped":
            raise AttributeError(attr)

        return getattr(self.wrapped, attr)

    def step(
        self, model_output: FloatTensor, timestep: Tensor, sample: FloatTensor
    ) -> SchedulerOutput:
        """Run the wrapped scheduler's step and optionally mirror its latents.

        Args:
            model_output: passed through to the wrapped scheduler.
            timestep: passed through to the wrapped scheduler.
            sample: passed through to the wrapped scheduler.

        Returns:
            The wrapped scheduler's result when mirroring is disabled.
            Otherwise a new ``SchedulerOutput`` carrying only the mirrored
            ``prev_sample``; any other field of the wrapped result is dropped.
        """
        result = self.wrapped.step(model_output, timestep, sample)

        if not (self.text_pipeline and self.server.has_feature("mirror-latents")):
            return result

        logger.info("using experimental latent mirroring")

        # Pick the axis to mirror along and the axes the 1-D gradient must be
        # expanded on to broadcast against 4-D latents.
        # NOTE(review): presumably latents are (batch, channels, height, width)
        # - confirm against the pipeline.
        if self.server.has_feature("mirror-latents-vertical"):
            axis_of_symmetry = 2
            expand_dims = (0, 1, 3)
        else:
            axis_of_symmetry = 3
            expand_dims = (0, 1, 2)

        # Gradient profile: weight 1 for the first `white_point` positions from
        # each edge, a linear ramp down to 0 at a quarter of the axis, and 0
        # from there to the center line.
        axis_length = result.prev_sample.shape[axis_of_symmetry]
        white_point = 2
        black_point = axis_length // 4
        center_line = axis_length // 2
        # NOTE(review): the gradient has 2 * center_line entries, so an odd
        # axis length will not broadcast against the latents - confirm sizes.
        gradient = linear_gradient(white_point, black_point, center_line, expand_dims)

        # Tensor.numpy() shares memory with the tensor; work on a copy so the
        # in-place add below does not modify the wrapped scheduler's output.
        latents = result.prev_sample.numpy().copy()
        # gradiated_latents = np.multiply(latents, gradient)
        inverse_gradiated_latents = np.multiply(
            np.flip(latents, axis=axis_of_symmetry), gradient
        )
        latents += inverse_gradiated_latents

        # The mask accumulates the same weights that were added to the latents
        # so the blended values can be normalized afterwards.
        mask = np.ones_like(latents).astype(np.float32)
        # gradiated_mask = np.multiply(mask, gradient)
        # flipping the mask would do nothing, we need to flip the gradient for this one
        inverse_gradiated_mask = np.multiply(
            mask, np.flip(gradient, axis=axis_of_symmetry)
        )
        mask += inverse_gradiated_mask

        latents = np.where(mask > 0, latents / mask, latents)

        return SchedulerOutput(
            prev_sample=torch.from_numpy(latents),
        )
2024-01-22 03:36:39 +00:00
2024-01-22 04:27:34 +00:00
def linear_gradient(
    white_point: int,
    black_point: int,
    center_line: int,
    expand_dims: tuple[int, ...] = (0, 1, 2),
) -> np.ndarray:
    """Build a symmetric float32 gradient that is 1 at both ends and 0 in the middle.

    One half of the profile is ``white_point`` ones, followed by a linear ramp
    from 1 to 0 with ``black_point - white_point`` samples, followed by
    ``center_line - black_point`` zeros. The full profile is that half
    followed by its reverse, so it has ``2 * center_line`` entries.

    Args:
        white_point: number of leading entries held at 1.
        black_point: index at which the ramp has finished.
        center_line: length of one half of the profile.
        expand_dims: axes inserted with ``np.expand_dims`` so the 1-D profile
            can be broadcast against a higher-rank array.

    Returns:
        The profile with the requested singleton axes added.
    """
    ramp = np.linspace(1, 0, black_point - white_point).astype(np.float32)

    # One pad call: ones in front of the ramp, zeros behind it.
    half_profile = np.pad(
        ramp,
        (white_point, center_line - black_point),
        mode="constant",
        constant_values=(1, 0),
    )

    # Mirror the half around the center line.
    full_profile = np.concatenate((half_profile, half_profile[::-1]))
    return np.expand_dims(full_profile, expand_dims)
2024-01-22 04:27:34 +00:00
2024-01-22 03:36:39 +00:00
2024-01-22 04:27:34 +00:00
def mirror_latents(
    latents: np.ndarray,
    gradient: np.ndarray,
    center_line: int,
    direction: Literal["horizontal", "vertical"],
) -> np.ndarray:
    """Blend 4-D latents with a gradient-weighted mirror image of themselves.

    The latents are padded so the mirrored axis is at least
    ``2 * center_line`` long, the padded array and its flipped copy are both
    weighted by ``gradient`` and added to the padded array, the sum is
    divided by the accumulated weights, and the padding is cropped off again.

    Args:
        latents: 4-D array; axis 3 is mirrored for ``"horizontal"`` and
            axis 2 for ``"vertical"``.
        gradient: weights that must broadcast against the padded latents with
            the mirrored axis last. For ``"vertical"`` the latents have their
            last two axes swapped before the gradient is applied, so the
            gradient is laid out along the trailing axis in both directions.
        center_line: position of the mirror line along the mirrored axis; it
            only determines how much padding is added.
        direction: which axis to mirror along.

    Returns:
        The blended latents, with the same shape as ``latents``.

    Raises:
        ValueError: if ``direction`` is neither ``"horizontal"`` nor
            ``"vertical"``.
    """
    if direction == "horizontal":
        return _mirror_last_axis(latents, gradient, center_line)

    if direction == "vertical":
        # Vertical mirroring is horizontal mirroring with the last two axes
        # swapped; swap back afterwards so the caller gets the blended result
        # in the original layout.
        swapped = latents.transpose(0, 1, 3, 2)
        mirrored = _mirror_last_axis(swapped, gradient, center_line)
        return mirrored.transpose(0, 1, 3, 2)

    raise ValueError("Invalid direction. Must be 'horizontal' or 'vertical'.")


def _mirror_last_axis(
    latents: np.ndarray, gradient: np.ndarray, center_line: int
) -> np.ndarray:
    """Blend 4-D latents with their gradient-weighted mirror along axis 3."""
    width = latents.shape[3]
    pad_before = max(0, -center_line)
    pad_after = max(0, 2 * center_line - width)
    pad_spec = ((0, 0), (0, 0), (0, 0), (pad_before, pad_after))

    # create the symmetrical copies
    padded = np.pad(latents, pad_spec, mode="constant")
    flipped = np.flip(padded, axis=3)

    # combine the original with both gradient-weighted copies
    blended = padded + np.multiply(padded, gradient) + np.multiply(flipped, gradient)

    # Accumulate the matching weights. np.flip returns a view, so the weights
    # are built out-of-place: an in-place update of the padded mask would
    # also change the flipped view before it is used.
    coverage = np.pad(
        np.ones_like(latents).astype(np.float32), pad_spec, mode="constant"
    )
    weights = (
        coverage
        + np.multiply(coverage, gradient)
        + np.multiply(np.flip(coverage, axis=3), gradient)
    )

    # Positions with zero weight keep their blended value; silence the
    # divide-by-zero warnings from the branch np.where discards.
    with np.errstate(divide="ignore", invalid="ignore"):
        normalized = np.where(weights > 0, blended / weights, blended)

    return normalized[:, :, :, pad_before : pad_before + width]