feat(api): add tiled VAE wrapper
commit 64a753e064 (parent eef85aabb8)
@@ -15,6 +15,8 @@ from ..models.meta import NetworkModel
 from ..params import DeviceParams
 from ..server import ServerContext
 from ..utils import run_gc
+from .patches.unet import UNetWrapper
+from .patches.vae import VAEWrapper
 from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
 from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
 from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
@@ -397,92 +399,6 @@ def optimize_pipeline(
         logger.warning("error while enabling memory efficient attention: %s", e)
 
-
-# TODO: does this need to change for fp16 modes?
-timestep_dtype = np.float32
-
-
-class UNetWrapper(object):
-    prompt_embeds: Optional[List[np.ndarray]] = None
-    prompt_index: int = 0
-    server: ServerContext
-    wrapped: OnnxRuntimeModel
-
-    def __init__(
-        self,
-        server: ServerContext,
-        wrapped: OnnxRuntimeModel,
-    ):
-        self.server = server
-        self.wrapped = wrapped
-
-    def __call__(
-        self,
-        sample: np.ndarray = None,
-        timestep: np.ndarray = None,
-        encoder_hidden_states: np.ndarray = None,
-        **kwargs,
-    ):
-        global timestep_dtype
-        timestep_dtype = timestep.dtype
-
-        logger.trace(
-            "UNet parameter types: %s, %s, %s",
-            sample.dtype,
-            timestep.dtype,
-            encoder_hidden_states.dtype,
-        )
-
-        if self.prompt_embeds is not None:
-            step_index = self.prompt_index % len(self.prompt_embeds)
-            logger.trace("multiple prompt embeds found, using step: %s", step_index)
-            encoder_hidden_states = self.prompt_embeds[step_index]
-            self.prompt_index += 1
-
-        if sample.dtype != timestep.dtype:
-            logger.trace("converting UNet sample to timestep dtype")
-            sample = sample.astype(timestep.dtype)
-
-        if encoder_hidden_states.dtype != timestep.dtype:
-            logger.trace("converting UNet hidden states to timestep dtype")
-            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
-
-        return self.wrapped(
-            sample=sample,
-            timestep=timestep,
-            encoder_hidden_states=encoder_hidden_states,
-            **kwargs,
-        )
-
-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)
-
-    def set_prompts(self, prompt_embeds: List[np.ndarray]):
-        logger.debug(
-            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
-        )
-        self.prompt_embeds = prompt_embeds
-        self.prompt_index = 0
-
-
-class VAEWrapper(object):
-    def __init__(self, server, wrapped):
-        self.server = server
-        self.wrapped = wrapped
-
-    def __call__(self, latent_sample=None, **kwargs):
-        global timestep_dtype
-
-        logger.trace("VAE parameter types: %s", latent_sample.dtype)
-        if latent_sample.dtype != timestep_dtype:
-            logger.info("converting VAE sample dtype")
-            latent_sample = latent_sample.astype(timestep_dtype)
-
-        return self.wrapped(latent_sample=latent_sample, **kwargs)
-
-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)
-
 
 def patch_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -496,7 +412,7 @@ def patch_pipeline(
 
     if hasattr(pipe, "vae_decoder"):
         original_vae = pipe.vae_decoder
-        pipe.vae_decoder = VAEWrapper(server, original_vae)
+        pipe.vae_decoder = VAEWrapper(server, original_vae, decoder=True)
    elif hasattr(pipe, "vae"):
        pass  # TODO: current wrapper does not work with upscaling VAE
    else:
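As context for the `decoder=True` change above, a minimal sketch of how these wrappers get wired into a pipeline. This is not the full `patch_pipeline` body: the package path and the `pipe.unet` patching are assumptions based on the imports added in the first hunk. The next hunk adds the new `patches/unet.py` module implied by those imports.

    # hedged sketch: wiring the patch wrappers into an existing ONNX pipeline
    from onnx_web.diffusers.patches.unet import UNetWrapper  # package path assumed
    from onnx_web.diffusers.patches.vae import VAEWrapper

    def patch_pipeline_sketch(server, pipe):
        pipe.unet = UNetWrapper(server, pipe.unet)
        if hasattr(pipe, "vae_decoder"):
            # decoder=True routes VAEWrapper.__call__ to tiled_decode
            pipe.vae_decoder = VAEWrapper(server, pipe.vae_decoder, decoder=True)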
@@ -0,0 +1,72 @@
+from diffusers import OnnxRuntimeModel
+from logging import getLogger
+from typing import List, Optional
+
+from ...server import ServerContext
+from .vae import set_vae_dtype
+
+import numpy as np
+
+logger = getLogger(__name__)
+
+
+class UNetWrapper(object):
+    prompt_embeds: Optional[List[np.ndarray]] = None
+    prompt_index: int = 0
+    server: ServerContext
+    wrapped: OnnxRuntimeModel
+
+    def __init__(
+        self,
+        server: ServerContext,
+        wrapped: OnnxRuntimeModel,
+    ):
+        self.server = server
+        self.wrapped = wrapped
+
+    def __call__(
+        self,
+        sample: np.ndarray = None,
+        timestep: np.ndarray = None,
+        encoder_hidden_states: np.ndarray = None,
+        **kwargs,
+    ):
+        logger.trace(
+            "UNet parameter types: %s, %s, %s",
+            sample.dtype,
+            timestep.dtype,
+            encoder_hidden_states.dtype,
+        )
+        set_vae_dtype(timestep.dtype)
+
+        if self.prompt_embeds is not None:
+            step_index = self.prompt_index % len(self.prompt_embeds)
+            logger.trace("multiple prompt embeds found, using step: %s", step_index)
+            encoder_hidden_states = self.prompt_embeds[step_index]
+            self.prompt_index += 1
+
+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet sample to timestep dtype")
+            sample = sample.astype(timestep.dtype)
+
+        if encoder_hidden_states.dtype != timestep.dtype:
+            logger.trace("converting UNet hidden states to timestep dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
+
+        return self.wrapped(
+            sample=sample,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            **kwargs,
+        )
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+    def set_prompts(self, prompt_embeds: List[np.ndarray]):
+        logger.debug(
+            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
+        )
+        self.prompt_embeds = prompt_embeds
+        self.prompt_index = 0
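The wrapper above exposes `set_prompts` so a pipeline can cycle through several prompt embeddings, one per UNet step. A small usage sketch (the embedding shapes are illustrative placeholders for SD-1.x; `server` and `pipe` are assumed to exist):

    import numpy as np

    unet = UNetWrapper(server, pipe.unet)
    # two alternating embeddings; __call__ uses embeds[prompt_index % len(embeds)]
    embeds = [
        np.zeros((1, 77, 768), dtype=np.float32),
        np.ones((1, 77, 768), dtype=np.float32),
    ]
    unet.set_prompts(embeds)
    # each UNet call now swaps in the next embedding and advances the index

The following hunk adds the new `patches/vae.py` module that provides `set_vae_dtype`.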
@@ -0,0 +1,165 @@
+from typing import Union
+from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+from diffusers.models.vae import DiagonalGaussianDistribution, DecoderOutput
+from logging import getLogger
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...server import ServerContext
+from diffusers import OnnxRuntimeModel
+
+logger = getLogger(__name__)
+
+LATENT_CHANNELS = 4
+SAMPLE_SIZE = 32
+
+# TODO: does this need to change for fp16 modes?
+timestep_dtype = np.float32
+
+
+def set_vae_dtype(dtype):
+    global timestep_dtype
+    timestep_dtype = dtype
+
+
+class VAEWrapper(object):
+    def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
+        self.server = server
+        self.wrapped = wrapped
+        self.decoder = decoder
+
+        self.tile_sample_min_size = SAMPLE_SIZE
+        # __getattr__ forwards self.config to the wrapped model's config
+        self.tile_latent_min_size = int(SAMPLE_SIZE / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
+        self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
+        self.post_quant_conv = nn.Conv2d(LATENT_CHANNELS, LATENT_CHANNELS, 1)
+
+    def __call__(self, latent_sample=None, **kwargs):
+        global timestep_dtype
+
+        logger.trace("VAE %s parameter types: %s", ("decoder" if self.decoder else "encoder"), latent_sample.dtype)
+        if latent_sample.dtype != timestep_dtype:
+            logger.info("converting VAE sample dtype")
+            latent_sample = latent_sample.astype(timestep_dtype)
+
+        if self.decoder:
+            return self.tiled_decode(latent_sample, **kwargs)
+        else:
+            return self.tiled_encode(latent_sample, **kwargs)
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+    def blend_v(self, a, b, blend_extent):
+        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a, b, blend_extent):
+        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
+        is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts,
+        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
+        the look of the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into overlapping tiles of tile_sample_min_size and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                # run the wrapped ONNX encoder directly; calling self(tile) here
+                # would recurse back into tiled_encode
+                tile = torch.from_numpy(self.wrapped(sample=tile.numpy())[0])
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        # note: DiagonalGaussianDistribution keeps torch tensors and has no
+        # numpy() conversion; callers sample from it and convert as needed
+        posterior = DiagonalGaussianDistribution(moments)
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of images using a tiled decoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding
+        is different from non-tiled decoding because each tile uses a different decoder. To avoid tiling artifacts,
+        the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in
+        the look of the output, but they should be much less noticeable.
+
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        if isinstance(z, np.ndarray):
+            z = torch.from_numpy(z)
+
+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping tiles of tile_latent_min_size and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = self.post_quant_conv(tile)
+                # run the wrapped ONNX decoder directly; calling self(tile) here
+                # would recurse back into tiled_decode (detach: post_quant_conv
+                # output carries grad and cannot convert to numpy otherwise)
+                decoded = torch.from_numpy(self.wrapped(latent_sample=tile.detach().numpy())[0])
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        dec = dec.numpy()
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
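The `blend_v`/`blend_h` helpers above crossfade two tiles linearly over `blend_extent` rows or columns. A standalone numpy sketch of the same weighting, reduced to 2D for clarity:

    import numpy as np

    def blend_rows(a: np.ndarray, b: np.ndarray, extent: int) -> np.ndarray:
        # fade from a into b over `extent` rows; the weight runs 0 -> 1 down the seam
        out = b.copy()
        for y in range(min(a.shape[0], b.shape[0], extent)):
            w = y / extent
            out[y] = a[-extent + y] * (1 - w) + b[y] * w
        return out

    a = np.full((4, 4), 1.0)
    b = np.full((4, 4), 3.0)
    print(blend_rows(a, b, 4))  # row values step 1.0, 1.5, 2.0, 2.5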
@@ -283,7 +283,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                    f" {negative_prompt_embeds.shape}."
                )
 
-    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
+    def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
        # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
        panorama_height /= 8
        panorama_width /= 8
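Because `get_views` works in latent space (the 8x division above), halving `window_size` from 64 to 32 shrinks each view from 512px to 256px square, the same 32-latent tile size used by the VAE wrapper above. A sketch of the MultiDiffusion view-count arithmetic (not the pipeline's exact code beyond what the hunk shows):

    def count_views(height_px, width_px, window_size=32, stride=8):
        # work in latent space: SD latents are 8x smaller than pixels
        h, w = height_px // 8, width_px // 8
        num_h = (h - window_size) // stride + 1
        num_w = (w - window_size) // stride + 1
        return num_h * num_w

    # 512x512 image -> 64x64 latents -> 5 windows per axis -> 25 views
    print(count_views(512, 512))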
@@ -11,15 +11,15 @@ from ..params import ImageParams, Size
 
 logger = getLogger(__name__)
 
-latent_channels = 4
-latent_factor = 8
+LATENT_CHANNELS = 4
+LATENT_FACTOR = 8
 MAX_TOKENS_PER_GROUP = 77
 
 CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
 INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):(-?[\.|\d]+)\>")
 LORA_TOKEN = compile(r"\<lora:([-\w]+):(-?[\.|\d]+)\>")
 INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
-ALTERNATIVE_RANGE = compile(r"\(([\w\|]+)\)")
+ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
 
 
 def expand_interval_ranges(prompt: str) -> str:
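The new `ALTERNATIVE_RANGE` pattern accepts any characters up to the closing parenthesis, where the old one only allowed word characters and pipes, so alternatives containing spaces or punctuation now match. For example:

    from re import compile

    OLD_RANGE = compile(r"\(([\w\|]+)\)")
    NEW_RANGE = compile(r"\(([^\)]+)\)")

    prompt = "a (red|dark blue) house"
    print(OLD_RANGE.search(prompt))           # None: the space breaks [\w|]+
    print(NEW_RANGE.search(prompt).group(1))  # red|dark blue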
@@ -253,9 +253,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
     """
     latents_shape = (
         batch,
-        latent_channels,
-        size.height // latent_factor,
-        size.width // latent_factor,
+        LATENT_CHANNELS,
+        size.height // LATENT_FACTOR,
+        size.width // LATENT_FACTOR,
     )
     rng = np.random.default_rng(seed)
     image_latents = rng.standard_normal(latents_shape).astype(np.float32)
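The constants are renamed, not changed: latents keep 4 channels and an 8x spatial reduction. Worked through, assuming `Size` holds pixel dimensions as used here:

    # 512x512 image, batch of 1:
    #   (batch, LATENT_CHANNELS, 512 // LATENT_FACTOR, 512 // LATENT_FACTOR)
    #   == (1, 4, 64, 64)
    latents = get_latents_from_seed(42, Size(512, 512), batch=1)  # Size args assumed
    assert latents.shape == (1, 4, 64, 64)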
@@ -266,9 +266,9 @@ def get_tile_latents(
     full_latents: np.ndarray, dims: Tuple[int, int, int]
 ) -> np.ndarray:
     x, y, tile = dims
-    t = tile // latent_factor
-    x = x // latent_factor
-    y = y // latent_factor
+    t = tile // LATENT_FACTOR
+    x = x // LATENT_FACTOR
+    y = y // LATENT_FACTOR
     xt = x + t
     yt = y + t
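`get_tile_latents` applies the same factor to convert pixel-space tile coordinates into latent-space bounds. A worked example (the final slice expression is an assumption; the hunk only shows the bounds being computed):

    x, y, tile = 256, 0, 512  # pixel-space tile origin and size
    t = tile // 8             # LATENT_FACTOR -> 64 latents per tile
    x, y = x // 8, y // 8     # -> (32, 0)
    xt, yt = x + t, y + t     # -> (96, 64)
    # presumably sliced as full_latents[:, :, y:yt, x:xt]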
@@ -14,6 +14,7 @@
   "settings": {
     "cSpell.words": [
       "astype",
+      "Autoencoder",
       "basicsr",
       "Civitai",
       "ckpt",