
feat(api): add tiled VAE wrapper

Sean Sube 2023-04-27 22:50:11 -05:00
parent eef85aabb8
commit 64a753e064
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 251 additions and 97 deletions

View File

@@ -15,6 +15,8 @@ from ..models.meta import NetworkModel
 from ..params import DeviceParams
 from ..server import ServerContext
 from ..utils import run_gc
+from .patches.unet import UNetWrapper
+from .patches.vae import VAEWrapper
 from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
 from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
 from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
@@ -397,92 +399,6 @@ def optimize_pipeline(
         logger.warning("error while enabling memory efficient attention: %s", e)


-# TODO: does this need to change for fp16 modes?
-timestep_dtype = np.float32


-class UNetWrapper(object):
-    prompt_embeds: Optional[List[np.ndarray]] = None
-    prompt_index: int = 0
-    server: ServerContext
-    wrapped: OnnxRuntimeModel

-    def __init__(
-        self,
-        server: ServerContext,
-        wrapped: OnnxRuntimeModel,
-    ):
-        self.server = server
-        self.wrapped = wrapped

-    def __call__(
-        self,
-        sample: np.ndarray = None,
-        timestep: np.ndarray = None,
-        encoder_hidden_states: np.ndarray = None,
-        **kwargs,
-    ):
-        global timestep_dtype
-        timestep_dtype = timestep.dtype

-        logger.trace(
-            "UNet parameter types: %s, %s, %s",
-            sample.dtype,
-            timestep.dtype,
-            encoder_hidden_states.dtype,
-        )

-        if self.prompt_embeds is not None:
-            step_index = self.prompt_index % len(self.prompt_embeds)
-            logger.trace("multiple prompt embeds found, using step: %s", step_index)
-            encoder_hidden_states = self.prompt_embeds[step_index]
-            self.prompt_index += 1

-        if sample.dtype != timestep.dtype:
-            logger.trace("converting UNet sample to timestep dtype")
-            sample = sample.astype(timestep.dtype)

-        if encoder_hidden_states.dtype != timestep.dtype:
-            logger.trace("converting UNet hidden states to timestep dtype")
-            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

-        return self.wrapped(
-            sample=sample,
-            timestep=timestep,
-            encoder_hidden_states=encoder_hidden_states,
-            **kwargs,
-        )

-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)

-    def set_prompts(self, prompt_embeds: List[np.ndarray]):
-        logger.debug(
-            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
-        )
-        self.prompt_embeds = prompt_embeds
-        self.prompt_index = 0


-class VAEWrapper(object):
-    def __init__(self, server, wrapped):
-        self.server = server
-        self.wrapped = wrapped

-    def __call__(self, latent_sample=None, **kwargs):
-        global timestep_dtype

-        logger.trace("VAE parameter types: %s", latent_sample.dtype)

-        if latent_sample.dtype != timestep_dtype:
-            logger.info("converting VAE sample dtype")
-            latent_sample = latent_sample.astype(timestep_dtype)

-        return self.wrapped(latent_sample=latent_sample, **kwargs)

-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)


 def patch_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -496,7 +412,7 @@ def patch_pipeline(

     if hasattr(pipe, "vae_decoder"):
         original_vae = pipe.vae_decoder
-        pipe.vae_decoder = VAEWrapper(server, original_vae)
+        pipe.vae_decoder = VAEWrapper(server, original_vae, decoder=True)
     elif hasattr(pipe, "vae"):
         pass  # TODO: current wrapper does not work with upscaling VAE
     else:
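
The patch relies on a simple delegation pattern: pipe.vae_decoder is replaced with a wrapper that intercepts __call__ and forwards every other attribute to the original OnnxRuntimeModel, so the rest of the pipeline keeps working unchanged. A minimal, self-contained sketch of that pattern (DummyModel and its fields are hypothetical stand-ins, not part of onnx-web):

class DummyModel:
    # stands in for the wrapped OnnxRuntimeModel in this sketch
    model_path = "/tmp/example.onnx"

    def __call__(self, latent_sample=None):
        return latent_sample


class Wrapper:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, latent_sample=None, **kwargs):
        # pre-processing (dtype conversion, tiling) would happen here
        return self.wrapped(latent_sample=latent_sample, **kwargs)

    def __getattr__(self, attr):
        # anything not defined on the wrapper resolves on the wrapped model
        return getattr(self.wrapped, attr)


model = Wrapper(DummyModel())
print(model.model_path)        # delegated attribute lookup
print(model(latent_sample=1))  # intercepted call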

View File

@@ -0,0 +1,72 @@
+from diffusers import OnnxRuntimeModel
+from logging import getLogger
+from typing import List, Optional

+from ...server import ServerContext
+from .vae import set_vae_dtype

+import numpy as np

+logger = getLogger(__name__)


+class UNetWrapper(object):
+    prompt_embeds: Optional[List[np.ndarray]] = None
+    prompt_index: int = 0
+    server: ServerContext
+    wrapped: OnnxRuntimeModel

+    def __init__(
+        self,
+        server: ServerContext,
+        wrapped: OnnxRuntimeModel,
+    ):
+        self.server = server
+        self.wrapped = wrapped

+    def __call__(
+        self,
+        sample: np.ndarray = None,
+        timestep: np.ndarray = None,
+        encoder_hidden_states: np.ndarray = None,
+        **kwargs,
+    ):
+        logger.trace(
+            "UNet parameter types: %s, %s, %s",
+            sample.dtype,
+            timestep.dtype,
+            encoder_hidden_states.dtype,
+        )
+        set_vae_dtype(timestep.dtype)

+        if self.prompt_embeds is not None:
+            step_index = self.prompt_index % len(self.prompt_embeds)
+            logger.trace("multiple prompt embeds found, using step: %s", step_index)
+            encoder_hidden_states = self.prompt_embeds[step_index]
+            self.prompt_index += 1

+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet sample to timestep dtype")
+            sample = sample.astype(timestep.dtype)

+        if encoder_hidden_states.dtype != timestep.dtype:
+            logger.trace("converting UNet hidden states to timestep dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)

+        return self.wrapped(
+            sample=sample,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            **kwargs,
+        )

+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)

+    def set_prompts(self, prompt_embeds: List[np.ndarray]):
+        logger.debug(
+            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
+        )
+        self.prompt_embeds = prompt_embeds
+        self.prompt_index = 0
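
The prompt handling above cycles through the embeds set by set_prompts, using the call count modulo the list length, so a two-element list alternates on every UNet call. A standalone sketch of just that indexing, with zero/one arrays standing in for real text-encoder output (the shapes are illustrative only):

import numpy as np

prompt_embeds = [np.zeros((1, 77, 768)), np.ones((1, 77, 768))]

prompt_index = 0
for step in range(5):
    step_index = prompt_index % len(prompt_embeds)
    encoder_hidden_states = prompt_embeds[step_index]
    prompt_index += 1
    print(step, step_index, encoder_hidden_states.mean())
# steps 0-4 use embeds 0, 1, 0, 1, 0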

View File

@@ -0,0 +1,165 @@
+from typing import Union

+from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+from diffusers.models.vae import DiagonalGaussianDistribution, DecoderOutput
+from logging import getLogger

+import numpy as np
+import torch
+import torch.nn as nn

+from ...server import ServerContext
+from diffusers import OnnxRuntimeModel

+logger = getLogger(__name__)

+LATENT_CHANNELS = 4
+SAMPLE_SIZE = 32

+# TODO: does this need to change for fp16 modes?
+timestep_dtype = np.float32


+def set_vae_dtype(dtype):
+    global timestep_dtype
+    timestep_dtype = dtype


+class VAEWrapper(object):
+    def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
+        self.server = server
+        self.wrapped = wrapped
+        self.decoder = decoder

+        self.tile_sample_min_size = SAMPLE_SIZE
+        self.tile_latent_min_size = int(SAMPLE_SIZE / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25

+        self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
+        self.post_quant_conv = nn.Conv2d(LATENT_CHANNELS, LATENT_CHANNELS, 1)

+    def __call__(self, latent_sample=None, **kwargs):
+        global timestep_dtype

+        logger.trace("VAE %s parameter types: %s", ("decoder" if self.decoder else "encoder"), latent_sample.dtype)

+        if latent_sample.dtype != timestep_dtype:
+            logger.info("converting VAE sample dtype")
+            latent_sample = latent_sample.astype(timestep_dtype)

+        if self.decoder:
+            return self.tiled_decode(latent_sample, **kwargs)
+        else:
+            return self.tiled_encode(latent_sample, **kwargs)

+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)

+    def blend_v(self, a, b, blend_extent):
+        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b

+    def blend_h(self, a, b, blend_extent):
+        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b

+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.

+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in
+        several steps. This is useful to keep memory use constant regardless of image size. The end result of
+        tiled encoding is different from non-tiled encoding because each tile uses a different encoder. To avoid
+        tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see
+        tile-sized changes in the look of the output, but they should be much less noticeable.

+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)

+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent

+        # Split the image into overlapping tiles and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self(tile)
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)

+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))

+        moments = torch.cat(result_rows, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+        posterior = posterior.numpy()

+        if not return_dict:
+            return (posterior,)

+        return AutoencoderKLOutput(latent_dist=posterior)

+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of images using a tiled decoder.

+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in
+        several steps. This is useful to keep memory use constant regardless of image size. The end result of
+        tiled decoding is different from non-tiled decoding because each tile uses a different decoder. To avoid
+        tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see
+        tile-sized changes in the look of the output, but they should be much less noticeable.

+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        if isinstance(z, np.ndarray):
+            z = torch.from_numpy(z)

+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent

+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = self.post_quant_conv(tile)
+                decoded = self(tile)
+                row.append(decoded)
+            rows.append(row)

+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))

+        dec = torch.cat(result_rows, dim=2)
+        dec = dec.numpy()

+        if not return_dict:
+            return (dec,)

+        return DecoderOutput(sample=dec)
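
The seam handling in blend_h and blend_v is a linear crossfade: across the overlap, the previous tile's weight falls from 1 to 0 while the new tile's weight rises from 0 to 1. A small standalone check of blend_h on two constant "tiles" (the sizes are made up for the example):

import torch


def blend_h(a, b, blend_extent):
    # same column-wise crossfade as VAEWrapper.blend_h above
    for x in range(min(a.shape[3], b.shape[3], blend_extent)):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b


left = torch.zeros((1, 1, 1, 8))
right = torch.ones((1, 1, 1, 8))
blended = blend_h(left, right.clone(), blend_extent=4)
print(blended[0, 0, 0])
# tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])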

View File

@@ -283,7 +283,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                 f" {negative_prompt_embeds.shape}."
             )

-    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
+    def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
         # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
         panorama_height /= 8
         panorama_width /= 8
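
Note that window_size is measured in latent pixels; get_views divides the panorama dimensions by 8 before tiling, so the new default of 32 corresponds to 256 image pixels per view instead of 512. A quick arithmetic check as a sketch:

LATENT_FACTOR = 8  # the latent grid is 8x smaller than the image on each side

for window_size in (64, 32):
    print(window_size, "latent px ->", window_size * LATENT_FACTOR, "image px per view")
# 64 latent px -> 512 image px per view
# 32 latent px -> 256 image px per view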

View File

@@ -11,15 +11,15 @@ from ..params import ImageParams, Size

 logger = getLogger(__name__)

-latent_channels = 4
-latent_factor = 8
+LATENT_CHANNELS = 4
+LATENT_FACTOR = 8

 MAX_TOKENS_PER_GROUP = 77

 CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
 INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):(-?[\.|\d]+)\>")
 LORA_TOKEN = compile(r"\<lora:([-\w]+):(-?[\.|\d]+)\>")
 INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
-ALTERNATIVE_RANGE = compile(r"\(([\w\|]+)\)")
+ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")


 def expand_interval_ranges(prompt: str) -> str:
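
The broadened ALTERNATIVE_RANGE pattern accepts any characters between the parentheses instead of only word characters and pipes, so alternatives containing spaces or punctuation now match. A small comparison, using a made-up prompt:

from re import compile

OLD_ALTERNATIVE_RANGE = compile(r"\(([\w\|]+)\)")
NEW_ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")

prompt = "(a cat|a dog, highly detailed)"
print(OLD_ALTERNATIVE_RANGE.search(prompt))           # None: spaces and commas are not in [\w|]
print(NEW_ALTERNATIVE_RANGE.search(prompt).group(1))  # a cat|a dog, highly detailed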
@@ -253,9 +253,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
     """
     latents_shape = (
         batch,
-        latent_channels,
-        size.height // latent_factor,
-        size.width // latent_factor,
+        LATENT_CHANNELS,
+        size.height // LATENT_FACTOR,
+        size.width // LATENT_FACTOR,
     )
     rng = np.random.default_rng(seed)
     image_latents = rng.standard_normal(latents_shape).astype(np.float32)
@@ -266,9 +266,9 @@ def get_tile_latents(
     full_latents: np.ndarray, dims: Tuple[int, int, int]
 ) -> np.ndarray:
     x, y, tile = dims
-    t = tile // latent_factor
-    x = x // latent_factor
-    y = y // latent_factor
+    t = tile // LATENT_FACTOR
+    x = x // LATENT_FACTOR
+    y = y // LATENT_FACTOR
     xt = x + t
     yt = y + t
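
The rename to LATENT_CHANNELS/LATENT_FACTOR does not change the arithmetic: Stable Diffusion latents have 4 channels and are 8x smaller than the image on each side. A quick standalone check of the shape math (512x512 is just an example size):

import numpy as np

LATENT_CHANNELS = 4
LATENT_FACTOR = 8

batch, height, width = 1, 512, 512
latents_shape = (batch, LATENT_CHANNELS, height // LATENT_FACTOR, width // LATENT_FACTOR)

rng = np.random.default_rng(42)
latents = rng.standard_normal(latents_shape).astype(np.float32)
print(latents.shape)  # (1, 4, 64, 64)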

View File

@@ -14,6 +14,7 @@
   "settings": {
     "cSpell.words": [
       "astype",
+      "Autoencoder",
       "basicsr",
       "Civitai",
       "ckpt",