1
0
Fork 0
onnx-web/api/onnx_web/diffusers/patches/vae.py

210 lines
8.3 KiB
Python
Raw Normal View History

2023-04-28 03:50:11 +00:00
from logging import getLogger
2023-04-28 18:56:36 +00:00
from typing import Union
2023-04-28 03:50:11 +00:00
import numpy as np
import torch
2023-04-28 18:56:36 +00:00
from diffusers import OnnxRuntimeModel
from diffusers.models.autoencoder_kl import AutoencoderKLOutput
2023-04-29 21:00:28 +00:00
from diffusers.models.vae import DecoderOutput
2023-06-06 04:26:11 +00:00
from diffusers.pipelines.onnx_utils import ORT_TO_NP_TYPE
2023-04-28 03:50:11 +00:00
from ...server import ServerContext
# Module logger, named after this module's import path.
logger = getLogger(name=__name__)

# Presumably the channel count of the VAE latent tensor; not used by the code
# in this file itself (likely imported elsewhere — verify before removing).
LATENT_CHANNELS = 4
class VAEWrapper(object):
    """Wrap an ONNX VAE encoder or decoder, with optional tiled execution.

    Calls are forwarded to ``wrapped`` after casting the input to the dtype the
    ONNX graph declares for its ``sample``/``latent_sample`` input. When tiling
    is enabled (see ``set_tiled``), the input is split into overlapping tiles,
    each tile is run through ``wrapped`` on its own, and the outputs are
    cross-faded at the seams before being concatenated back together.

    Attribute lookups that fail on the wrapper are delegated to ``wrapped``.
    """

    def __init__(
        self,
        server: ServerContext,
        wrapped: OnnxRuntimeModel,
        decoder: bool,
        window: int,
        overlap: float,
    ):
        """
        Args:
            server: server context; stored on the instance but not read by
                this class.
            wrapped: the ONNX model (encoder or decoder) to run.
            decoder: True when ``wrapped`` is a decoder (fed ``latent_sample``),
                False when it is an encoder (fed ``sample``).
            window: tile edge length on the latent side; see ``set_window_size``.
            overlap: fraction of each tile shared with its neighbour.
        """
        self.server = server
        self.wrapped = wrapped
        self.decoder = decoder
        self.tiled = False
        self.set_window_size(window, overlap)

    def set_tiled(self, tiled: bool = True):
        """Enable (default) or disable tiled encoding/decoding for later calls."""
        self.tiled = tiled

    def set_window_size(self, window: int, overlap: float):
        """Set the tile size and overlap used by tiled encode/decode.

        ``window`` is the tile edge length on the latent side; the matching
        edge on the image side is ``window * 8`` (8 is presumably the VAE's
        spatial downscaling factor). ``overlap`` is the fraction of a tile that
        is shared with each neighbour, so tiles start every
        ``size * (1 - overlap)`` pixels.
        """
        self.tile_latent_min_size = window
        self.tile_sample_min_size = window * 8
        self.tile_overlap_factor = overlap

    def __call__(self, latent_sample=None, sample=None, **kwargs):
        """Run the wrapped model, tiled or not.

        The decoder reads ``latent_sample`` and the encoder reads ``sample``.
        Whichever is given is first cast (via ``astype``) to the numpy dtype the
        ONNX graph declares for its ``sample``/``latent_sample`` input, falling
        back to ``tensor(float)`` when no such input exists. Extra keyword
        arguments (e.g. ``return_dict``) are forwarded only to
        ``tiled_decode``/``tiled_encode``; the untiled path ignores them.
        """
        # the ONNX session is exposed as `model` on OnnxRuntimeModel; other
        # wrappers expose it as `session`
        model = (
            self.wrapped.model
            if hasattr(self.wrapped, "model")
            else self.wrapped.session
        )

        # set sample dtype to the graph's declared input type (e.g. fp16 models)
        sample_dtype = next(
            (
                model_input.type
                for model_input in model.get_inputs()
                if model_input.name in ("sample", "latent_sample")
            ),
            "tensor(float)",
        )
        sample_dtype = ORT_TO_NP_TYPE[sample_dtype]

        logger.trace(
            "VAE %s parameter types: %s, %s",
            ("decoder" if self.decoder else "encoder"),
            (latent_sample.dtype if latent_sample is not None else "none"),
            (sample.dtype if sample is not None else "none"),
        )

        if latent_sample is not None and latent_sample.dtype != sample_dtype:
            logger.debug("converting VAE latent sample dtype to %s", sample_dtype)
            latent_sample = latent_sample.astype(sample_dtype)

        if sample is not None and sample.dtype != sample_dtype:
            logger.debug("converting VAE sample dtype to %s", sample_dtype)
            sample = sample.astype(sample_dtype)

        if self.tiled:
            if self.decoder:
                return self.tiled_decode(latent_sample, **kwargs)
            else:
                return self.tiled_encode(sample, **kwargs)
        else:
            if self.decoder:
                return self.wrapped(latent_sample=latent_sample)
            else:
                return self.wrapped(sample=sample)

    def __getattr__(self, attr):
        """Delegate unknown attributes to the wrapped model."""
        # __getattr__ only runs when normal lookup fails. If `wrapped` itself is
        # missing (e.g. copy/unpickle creates the instance without __init__),
        # delegating would re-enter this method forever; fail cleanly instead.
        if attr == "wrapped":
            raise AttributeError(attr)
        return getattr(self.wrapped, attr)

    def blend_v(self, a, b, blend_extent):
        """Cross-fade the bottom rows (dim 2) of ``a`` into the top rows of ``b``.

        Row ``y`` of ``b`` becomes ``a[-blend_extent + y] * (1 - w) + b[y] * w``
        with ``w = y / blend_extent``, for up to ``blend_extent`` rows. ``b`` is
        modified in place and also returned.
        """
        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[
                :, :, y, :
            ] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        """Cross-fade the right columns (dim 3) of ``a`` into the left columns of ``b``.

        Column ``x`` of ``b`` becomes ``a[-blend_extent + x] * (1 - w) + b[x] * w``
        with ``w = x / blend_extent``, for up to ``blend_extent`` columns. ``b``
        is modified in place and also returned.
        """
        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[
                :, :, :, x
            ] * (x / blend_extent)
        return b

    def _run_tiles(self, tensor, tile_size, stride, run_tile):
        """Cut ``tensor`` into overlapping tiles and run ``run_tile`` on each.

        Tiles start every ``stride`` pixels along dims 2 and 3 and are up to
        ``tile_size`` pixels on a side (tiles at the far edges may be smaller).
        ``run_tile`` receives a numpy array and returns a numpy array. Returns a
        list of rows, each a list of the output tiles as torch tensors.

        Raises:
            ValueError: if ``stride`` is less than 1.
        """
        if stride < 1:
            # range() would otherwise fail with an opaque "arg 3 must not be
            # zero", or silently produce no tiles for a negative stride
            raise ValueError(
                f"VAE tile stride must be at least 1, got {stride} "
                f"(window={self.tile_latent_min_size}, "
                f"overlap={self.tile_overlap_factor})"
            )

        rows = []
        for i in range(0, tensor.shape[2], stride):
            row = []
            for j in range(0, tensor.shape[3], stride):
                tile = tensor[:, :, i : i + tile_size, j : j + tile_size]
                row.append(torch.from_numpy(run_tile(tile.numpy())))
            rows.append(row)
        return rows

    def _stitch_tiles(self, rows, blend_extent, row_limit):
        """Blend each tile with its upper and left neighbours, crop, and concatenate.

        Tiles in ``rows`` are blended in place, so a tile that has already been
        blended is what its lower/right neighbours blend against. Each tile is
        cropped to ``row_limit`` pixels on both spatial dims before
        concatenation. Returns the stitched torch tensor.
        """
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))
        return torch.cat(result_rows, dim=2)

    @torch.no_grad()
    def tiled_encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        The input is split into overlapping tiles of ``tile_sample_min_size``
        pixels that are encoded separately, which bounds the memory used per
        model call regardless of image size. Because each tile is encoded on its
        own, the result differs from untiled encoding. Overlapping regions are
        cross-faded to hide seams, though tile-sized changes in the output may
        still be visible.

        Args:
            x (`torch.FloatTensor` or `np.ndarray`):
                4-D input batch of images; dims 2 and 3 are tiled (presumably
                NCHW — confirm against callers).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`AutoencoderKLOutput`] instead of a plain
                tuple.

        Returns:
            [`AutoencoderKLOutput`] whose ``latent_dist`` is the stitched
            encoder output as a numpy array (not a distribution object), or a
            1-tuple holding that array when ``return_dict`` is False.

        Raises:
            ValueError: if the window/overlap settings give a tile stride < 1.
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)

        stride = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # tiles are cut in image space, then blended and cropped in latent space
        rows = self._run_tiles(
            x,
            self.tile_sample_min_size,
            stride,
            lambda tile: self.wrapped(sample=tile)[0],
        )
        moments = self._stitch_tiles(rows, blend_extent, row_limit).numpy()

        if not return_dict:
            return (moments,)

        return AutoencoderKLOutput(latent_dist=moments)

    @torch.no_grad()
    def tiled_decode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""Decode a batch of latents using a tiled decoder.

        The input is split into overlapping tiles of ``tile_latent_min_size``
        latent pixels that are decoded separately, which bounds the memory used
        per model call regardless of image size. Because each tile is decoded
        on its own, the result differs from untiled decoding. Overlapping
        regions are cross-faded to hide seams, though tile-sized changes in the
        output may still be visible.

        Args:
            z (`torch.FloatTensor` or `np.ndarray`):
                4-D input batch of latent vectors; dims 2 and 3 are tiled.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.

        Returns:
            [`DecoderOutput`] whose ``sample`` is the stitched decoder output as
            a numpy array, or a 1-tuple holding that array when ``return_dict``
            is False.

        Raises:
            ValueError: if the window/overlap settings give a tile stride < 1.
        """
        if isinstance(z, np.ndarray):
            z = torch.from_numpy(z)

        stride = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # tiles are cut in latent space, then blended and cropped in image space;
        # the overlap avoids visible seams between tiles
        rows = self._run_tiles(
            z,
            self.tile_latent_min_size,
            stride,
            lambda tile: self.wrapped(latent_sample=tile)[0],
        )
        dec = self._stitch_tiles(rows, blend_extent, row_limit).numpy()

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)