2023-04-28 03:50:11 +00:00
|
|
|
from logging import getLogger
|
2023-04-28 18:56:36 +00:00
|
|
|
from typing import Union
|
2023-04-28 03:50:11 +00:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
2023-04-28 18:56:36 +00:00
|
|
|
from diffusers import OnnxRuntimeModel
|
|
|
|
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
|
2023-04-29 21:00:28 +00:00
|
|
|
from diffusers.models.vae import DecoderOutput
|
2023-06-06 04:16:08 +00:00
|
|
|
from onnx.helper import tensor_dtype_to_np_dtype
|
2023-04-28 03:50:11 +00:00
|
|
|
|
|
|
|
from ...server import ServerContext
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = getLogger(__name__)
|
|
|
|
|
|
|
|
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably the channel count of the VAE latent tensors - confirm against callers.
LATENT_CHANNELS = 4
|
|
|
|
|
|
|
|
|
|
|
|
class VAEWrapper(object):
    """Wrap an ONNX VAE encoder or decoder with dtype casting and optional tiling.

    Calling the wrapper casts the incoming array to the element type declared by
    the wrapped graph's ``sample``/``latent_sample`` input, then either runs the
    wrapped model once or runs it over overlapping tiles and blends the results.

    Attributes that are not found on the wrapper itself are looked up on the
    wrapped model.
    """

    def __init__(
        self,
        server: ServerContext,
        wrapped: OnnxRuntimeModel,
        decoder: bool,
        window: int,
        overlap: float,
    ):
        """Store the wrapped model and the tiling configuration.

        Args:
            server: server context; stored on the instance, not otherwise used
                by the code in this class.
            wrapped: the model that is actually executed.
            decoder: ``True`` to run as a decoder (``latent_sample`` input),
                ``False`` to run as an encoder (``sample`` input).
            window: tile edge length in latent units, see `set_window_size`.
            overlap: overlap factor between neighbouring tiles, see
                `set_window_size`.
        """
        self.server = server
        self.wrapped = wrapped
        self.decoder = decoder
        # tiling is opt-in; enable it with set_tiled()
        self.tiled = False

        self.set_window_size(window, overlap)

    def set_tiled(self, tiled: bool = True):
        """Enable or disable tiled execution for later calls."""
        self.tiled = tiled

    def set_window_size(self, window: int, overlap: float):
        """Configure the tile size and the overlap between tiles.

        Args:
            window: tile edge length used for latent inputs; the tile edge
                length used for sample inputs is ``window * 8``.
            overlap: factor applied to the tile size to get the blended band;
                the stride between tiles is ``tile size * (1 - overlap)``.
        """
        self.tile_latent_min_size = window
        self.tile_sample_min_size = window * 8
        self.tile_overlap_factor = overlap

    def __call__(self, latent_sample=None, sample=None, **kwargs):
        """Run the wrapped VAE, casting the input to the graph's input dtype.

        Args:
            latent_sample: decoder input, or ``None``.
            sample: encoder input, or ``None``.
            **kwargs: forwarded to `tiled_decode`/`tiled_encode` when tiling is
                enabled; not forwarded to the wrapped model otherwise.

        Returns:
            Whatever the tiled method or the wrapped model returns.

        Raises:
            IndexError: if the wrapped graph has no input named ``sample`` or
                ``latent_sample``.
        """
        # look up the element type that the graph declares for its sample input
        inputs = self.wrapped.model.graph.input
        sample_input = [i for i in inputs if i.name in ("sample", "latent_sample")][0]
        sample_dtype = tensor_dtype_to_np_dtype(sample_input.type.tensor_type.elem_type)

        logger.trace(
            "VAE %s parameter types: %s, %s",
            ("decoder" if self.decoder else "encoder"),
            (latent_sample.dtype if latent_sample is not None else "none"),
            (sample.dtype if sample is not None else "none"),
        )

        # cast whichever inputs were provided so they match the graph
        if latent_sample is not None and latent_sample.dtype != sample_dtype:
            logger.debug("converting VAE latent sample dtype to %s", sample_dtype)
            latent_sample = latent_sample.astype(sample_dtype)

        if sample is not None and sample.dtype != sample_dtype:
            logger.debug("converting VAE sample dtype to %s", sample_dtype)
            sample = sample.astype(sample_dtype)

        if self.tiled:
            if self.decoder:
                return self.tiled_decode(latent_sample, **kwargs)

            return self.tiled_encode(sample, **kwargs)

        if self.decoder:
            return self.wrapped(latent_sample=latent_sample)

        return self.wrapped(sample=sample)

    def __getattr__(self, attr):
        """Delegate attributes that are missing on the wrapper to the wrapped model.

        Raises:
            AttributeError: if ``wrapped`` itself has not been set, or if the
                wrapped model does not have the attribute either.
        """
        # __getattr__ only runs after normal lookup has failed. When `wrapped`
        # is missing (instance built without __init__, e.g. by copy or pickle),
        # reading self.wrapped below would re-enter this method without end.
        if attr == "wrapped":
            raise AttributeError(attr)

        return getattr(self.wrapped, attr)

    def blend_v(self, a, b, blend_extent):
        """Blend the trailing rows (axis 2) of `a` into the leading rows of `b`.

        `b` is modified in place and returned. The weight of `b` grows linearly
        from 0 at row 0, in steps of ``1 / blend_extent``.
        """
        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
            weight = y / blend_extent
            b[:, :, y, :] = (
                a[:, :, -blend_extent + y, :] * (1 - weight) + b[:, :, y, :] * weight
            )

        return b

    def blend_h(self, a, b, blend_extent):
        """Blend the trailing columns (axis 3) of `a` into the leading columns of `b`.

        `b` is modified in place and returned. The weight of `b` grows linearly
        from 0 at column 0, in steps of ``1 / blend_extent``.
        """
        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
            weight = x / blend_extent
            b[:, :, :, x] = (
                a[:, :, :, -blend_extent + x] * (1 - weight) + b[:, :, :, x] * weight
            )

        return b

    def _run_tiles(self, data, input_name, tile_size, stride, blend_extent, limit):
        """Run the wrapped model over overlapping tiles and stitch the outputs.

        Args:
            data: torch tensor that is tiled along axes 2 and 3.
            input_name: keyword under which each tile is passed to the wrapped
                model (``"sample"`` or ``"latent_sample"``).
            tile_size: edge length of each input tile.
            stride: step between the origins of neighbouring input tiles.
            blend_extent: number of output rows/columns blended between
                neighbouring tiles.
            limit: number of output rows/columns kept from each blended tile.

        Returns:
            A torch tensor with the cropped tiles concatenated along axes 2 and 3.
        """
        # first pass: run every tile through the wrapped model
        rows = []
        for i in range(0, data.shape[2], stride):
            row = []
            for j in range(0, data.shape[3], stride):
                tile = data[:, :, i : i + tile_size, j : j + tile_size]
                output = self.wrapped(**{input_name: tile.numpy()})[0]
                row.append(torch.from_numpy(output))

            rows.append(row)

        # second pass: blend each tile with the tile above and the tile to the
        # left, then crop it so that the overlapping band is only kept once
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)

                result_row.append(tile[:, :, :limit, :limit])

            result_rows.append(torch.cat(result_row, dim=3))

        return torch.cat(result_rows, dim=2)

    @torch.no_grad()
    def tiled_encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        The input is split into overlapping tiles of `tile_sample_min_size`,
        each tile is encoded separately, and the outputs are blended and
        concatenated. Because tiles are encoded independently, the result can
        differ from encoding the whole input at once; the overlap and blending
        are there to make the seams between tiles less visible.

        Args:
            x (`torch.FloatTensor` or `np.ndarray`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            An [`AutoencoderKLOutput`] whose `latent_dist` is the stitched numpy
            array, or a one-element tuple holding that array.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)

        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # split the input into tile_sample_min_size tiles and encode them separately
        moments = self._run_tiles(
            x,
            "sample",
            self.tile_sample_min_size,
            overlap_size,
            blend_extent,
            row_limit,
        ).numpy()

        if not return_dict:
            return (moments,)

        return AutoencoderKLOutput(latent_dist=moments)

    @torch.no_grad()
    def tiled_decode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""Decode a batch of images using a tiled decoder.

        The input is split into overlapping tiles of `tile_latent_min_size`,
        each tile is decoded separately, and the outputs are blended and
        concatenated. Because tiles are decoded independently, the result can
        differ from decoding the whole input at once; the overlap and blending
        are there to make the seams between tiles less visible.

        Args:
            z (`torch.FloatTensor` or `np.ndarray`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.

        Returns:
            A [`DecoderOutput`] whose `sample` is the stitched numpy array, or a
            one-element tuple holding that array.
        """
        if isinstance(z, np.ndarray):
            z = torch.from_numpy(z)

        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # split z into overlapping tile_latent_min_size tiles and decode them
        # separately; the tiles overlap to avoid seams between them
        dec = self._run_tiles(
            z,
            "latent_sample",
            self.tile_latent_min_size,
            overlap_size,
            blend_extent,
            row_limit,
        )
        dec = dec.numpy()

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
|