feat(api): add tiled VAE wrapper
parent eef85aabb8
commit 64a753e064
@@ -15,6 +15,8 @@ from ..models.meta import NetworkModel
 from ..params import DeviceParams
 from ..server import ServerContext
 from ..utils import run_gc
+from .patches.unet import UNetWrapper
+from .patches.vae import VAEWrapper
 from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
 from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
 from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
@@ -397,92 +399,6 @@ def optimize_pipeline(
         logger.warning("error while enabling memory efficient attention: %s", e)


-# TODO: does this need to change for fp16 modes?
-timestep_dtype = np.float32
-
-
-class UNetWrapper(object):
-    prompt_embeds: Optional[List[np.ndarray]] = None
-    prompt_index: int = 0
-    server: ServerContext
-    wrapped: OnnxRuntimeModel
-
-    def __init__(
-        self,
-        server: ServerContext,
-        wrapped: OnnxRuntimeModel,
-    ):
-        self.server = server
-        self.wrapped = wrapped
-
-    def __call__(
-        self,
-        sample: np.ndarray = None,
-        timestep: np.ndarray = None,
-        encoder_hidden_states: np.ndarray = None,
-        **kwargs,
-    ):
-        global timestep_dtype
-        timestep_dtype = timestep.dtype
-
-        logger.trace(
-            "UNet parameter types: %s, %s, %s",
-            sample.dtype,
-            timestep.dtype,
-            encoder_hidden_states.dtype,
-        )
-
-        if self.prompt_embeds is not None:
-            step_index = self.prompt_index % len(self.prompt_embeds)
-            logger.trace("multiple prompt embeds found, using step: %s", step_index)
-            encoder_hidden_states = self.prompt_embeds[step_index]
-            self.prompt_index += 1
-
-        if sample.dtype != timestep.dtype:
-            logger.trace("converting UNet sample to timestep dtype")
-            sample = sample.astype(timestep.dtype)
-
-        if encoder_hidden_states.dtype != timestep.dtype:
-            logger.trace("converting UNet hidden states to timestep dtype")
-            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
-
-        return self.wrapped(
-            sample=sample,
-            timestep=timestep,
-            encoder_hidden_states=encoder_hidden_states,
-            **kwargs,
-        )
-
-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)
-
-    def set_prompts(self, prompt_embeds: List[np.ndarray]):
-        logger.debug(
-            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
-        )
-        self.prompt_embeds = prompt_embeds
-        self.prompt_index = 0
-
-
-class VAEWrapper(object):
-    def __init__(self, server, wrapped):
-        self.server = server
-        self.wrapped = wrapped
-
-    def __call__(self, latent_sample=None, **kwargs):
-        global timestep_dtype
-
-        logger.trace("VAE parameter types: %s", latent_sample.dtype)
-        if latent_sample.dtype != timestep_dtype:
-            logger.info("converting VAE sample dtype")
-            latent_sample = latent_sample.astype(timestep_dtype)
-
-        return self.wrapped(latent_sample=latent_sample, **kwargs)
-
-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)
-
-
 def patch_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -496,7 +412,7 @@ def patch_pipeline(

     if hasattr(pipe, "vae_decoder"):
         original_vae = pipe.vae_decoder
-        pipe.vae_decoder = VAEWrapper(server, original_vae)
+        pipe.vae_decoder = VAEWrapper(server, original_vae, decoder=True)
     elif hasattr(pipe, "vae"):
         pass  # TODO: current wrapper does not work with upscaling VAE
     else:
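Both wrapper classes, previously defined inline in this file and removed above, now live in a new patches package (the two new files below). They stay transparent to the rest of the pipeline because __getattr__ is only invoked when normal attribute lookup fails, so configs, session handles, and anything else not overridden falls through to the wrapped ONNX model. A standalone sketch of that delegation pattern, not taken from the project:

    class Transparent:
        def __init__(self, wrapped):
            self.wrapped = wrapped

        def __getattr__(self, attr):
            # only reached when the attribute is not found on the wrapper
            # itself, so self.wrapped resolves normally and everything else
            # is forwarded to the wrapped object
            return getattr(self.wrapped, attr)

    class Model:
        name = "onnx"

    print(Transparent(Model()).name)  # "onnx", via delegation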
diffusers/patches/unet.py (new file):
@@ -0,0 +1,72 @@
+from diffusers import OnnxRuntimeModel
+from logging import getLogger
+from typing import List, Optional
+
+from ...server import ServerContext
+from .vae import set_vae_dtype
+
+import numpy as np
+
+logger = getLogger(__name__)
+
+
+class UNetWrapper(object):
+    prompt_embeds: Optional[List[np.ndarray]] = None
+    prompt_index: int = 0
+    server: ServerContext
+    wrapped: OnnxRuntimeModel
+
+    def __init__(
+        self,
+        server: ServerContext,
+        wrapped: OnnxRuntimeModel,
+    ):
+        self.server = server
+        self.wrapped = wrapped
+
+    def __call__(
+        self,
+        sample: np.ndarray = None,
+        timestep: np.ndarray = None,
+        encoder_hidden_states: np.ndarray = None,
+        **kwargs,
+    ):
+        logger.trace(
+            "UNet parameter types: %s, %s, %s",
+            sample.dtype,
+            timestep.dtype,
+            encoder_hidden_states.dtype,
+        )
+        set_vae_dtype(timestep.dtype)
+
+        if self.prompt_embeds is not None:
+            step_index = self.prompt_index % len(self.prompt_embeds)
+            logger.trace("multiple prompt embeds found, using step: %s", step_index)
+            encoder_hidden_states = self.prompt_embeds[step_index]
+            self.prompt_index += 1
+
+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet sample to timestep dtype")
+            sample = sample.astype(timestep.dtype)
+
+        if encoder_hidden_states.dtype != timestep.dtype:
+            logger.trace("converting UNet hidden states to timestep dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
+
+        return self.wrapped(
+            sample=sample,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            **kwargs,
+        )
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+    def set_prompts(self, prompt_embeds: List[np.ndarray]):
+        logger.debug(
+            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
+        )
+        self.prompt_embeds = prompt_embeds
+        self.prompt_index = 0
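Two details of the new module are worth noting. First, __call__ now reports the scheduler's timestep dtype to the VAE patch through set_vae_dtype on every step, replacing the old shared global timestep_dtype; this is how the decoder later knows what dtype to convert its latents to. Second, set_prompts enables per-step prompt cycling: with N embeddings set, step k uses embedding k mod N. A rough sketch of the cycling behavior, with hypothetical embedding arrays (shapes are typical for SD text encoder output, not taken from this diff):

    import numpy as np

    embeds_a = np.zeros((2, 77, 768), dtype=np.float32)
    embeds_b = np.ones((2, 77, 768), dtype=np.float32)

    unet.set_prompts([embeds_a, embeds_b])  # unet: a UNetWrapper instance
    # step 0 -> embeds_a, step 1 -> embeds_b, step 2 -> embeds_a, ...
    # (step_index = prompt_index % len(prompt_embeds), as in __call__)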
diffusers/patches/vae.py (new file):
@@ -0,0 +1,165 @@
+from typing import Union
+
+from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+from diffusers.models.vae import DiagonalGaussianDistribution, DecoderOutput
+from logging import getLogger
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...server import ServerContext
+from diffusers import OnnxRuntimeModel
+
+logger = getLogger(__name__)
+
+LATENT_CHANNELS = 4
+SAMPLE_SIZE = 32
+
+# TODO: does this need to change for fp16 modes?
+timestep_dtype = np.float32
+
+
+def set_vae_dtype(dtype):
+    global timestep_dtype
+    timestep_dtype = dtype
+
+
+class VAEWrapper(object):
+    def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
+        self.server = server
+        self.wrapped = wrapped
+        self.decoder = decoder
+
+        self.tile_sample_min_size = SAMPLE_SIZE
+        self.tile_latent_min_size = int(SAMPLE_SIZE / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
+        self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
+        self.post_quant_conv = nn.Conv2d(LATENT_CHANNELS, LATENT_CHANNELS, 1)
+
+    def __call__(self, latent_sample=None, **kwargs):
+        global timestep_dtype
+
+        logger.trace("VAE %s parameter types: %s", ("decoder" if self.decoder else "encoder"), latent_sample.dtype)
+        if latent_sample.dtype != timestep_dtype:
+            logger.info("converting VAE sample dtype")
+            latent_sample = latent_sample.astype(timestep_dtype)
+
+        if self.decoder:
+            return self.tiled_decode(latent_sample, **kwargs)
+        else:
+            return self.tiled_encode(latent_sample, **kwargs)
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+    def blend_v(self, a, b, blend_extent):
+        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a, b, blend_extent):
+        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
+        is different from non-tiled encoding because each tile is encoded separately. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        look of the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into tiles of tile_sample_min_size and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = self(tile)
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+        posterior = posterior.numpy()
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of images using a tiled decoder.
+
+        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
+        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding
+        is different from non-tiled decoding because each tile is decoded separately. To avoid tiling artifacts, the
+        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+        look of the output, but they should be much less noticeable.
+
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+        if isinstance(z, np.ndarray):
+            z = torch.from_numpy(z)
+
+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping tiles of tile_latent_min_size and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = self.post_quant_conv(tile)
+                decoded = self(tile)
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        dec = dec.numpy()
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
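The tile geometry comes from SAMPLE_SIZE and the wrapped model's config: assuming the typical four-entry block_out_channels of a Stable Diffusion VAE (not stated in this diff), tile_latent_min_size works out to int(32 / 2**3) = 4. The seam removal itself is the linear cross-fade in blend_v/blend_h: over the first blend_extent rows or columns of the later tile, the weight ramps from the earlier tile to the later one. A self-contained check of the horizontal case, using the same math as the method above:

    import torch

    def blend_h(a, b, blend_extent):
        # fade from the right edge of tile a into the left edge of tile b
        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    left = torch.zeros(1, 1, 2, 8)
    right = torch.ones(1, 1, 2, 8)
    print(blend_h(left, right, 4)[0, 0, 0])
    # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000, 1.0000, 1.0000])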
@@ -283,7 +283,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                 f" {negative_prompt_embeds.shape}."
             )

-    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
+    def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
         # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
         panorama_height /= 8
         panorama_width /= 8
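Halving window_size from 64 to 32 latent pixels matches the VAE patch's SAMPLE_SIZE and yields more, smaller overlapping views per panorama. Assuming get_views uses the usual MultiDiffusion grid formula (its body is not part of this hunk), the view count for a 1024x512 image works out as:

    # latents are 1/8 the pixel size, per the two divisions above
    panorama_width, panorama_height = 1024 // 8, 512 // 8  # 128, 64
    window_size, stride = 32, 8
    num_w = (panorama_width - window_size) // stride + 1   # 13
    num_h = (panorama_height - window_size) // stride + 1  # 5
    print(num_w * num_h)  # 65 overlapping views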
@@ -11,15 +11,15 @@ from ..params import ImageParams, Size

 logger = getLogger(__name__)

-latent_channels = 4
-latent_factor = 8
+LATENT_CHANNELS = 4
+LATENT_FACTOR = 8
 MAX_TOKENS_PER_GROUP = 77

 CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
 INVERSION_TOKEN = compile(r"\<inversion:([-\w]+):(-?[\.|\d]+)\>")
 LORA_TOKEN = compile(r"\<lora:([-\w]+):(-?[\.|\d]+)\>")
 INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
-ALTERNATIVE_RANGE = compile(r"\(([\w\|]+)\)")
+ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")


 def expand_interval_ranges(prompt: str) -> str:
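The ALTERNATIVE_RANGE change is small but user-visible: the old pattern only accepted word characters and pipes between the parentheses, so any alternative containing a space never matched. A quick comparison:

    from re import compile

    OLD = compile(r"\(([\w\|]+)\)")
    NEW = compile(r"\(([^\)]+)\)")

    prompt = "(oil painting|pencil sketch) of a castle"
    print(OLD.findall(prompt))  # [] - the space breaks the old character class
    print(NEW.findall(prompt))  # ['oil painting|pencil sketch']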
@@ -253,9 +253,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
     """
     latents_shape = (
         batch,
-        latent_channels,
-        size.height // latent_factor,
-        size.width // latent_factor,
+        LATENT_CHANNELS,
+        size.height // LATENT_FACTOR,
+        size.width // LATENT_FACTOR,
     )
     rng = np.random.default_rng(seed)
     image_latents = rng.standard_normal(latents_shape).astype(np.float32)
@@ -266,9 +266,9 @@ def get_tile_latents(
     full_latents: np.ndarray, dims: Tuple[int, int, int]
 ) -> np.ndarray:
     x, y, tile = dims
-    t = tile // latent_factor
-    x = x // latent_factor
-    y = y // latent_factor
+    t = tile // LATENT_FACTOR
+    x = x // LATENT_FACTOR
+    y = y // LATENT_FACTOR
     xt = x + t
     yt = y + t

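The renames in these two hunks are mechanical, but the division they feed is the pixel-to-latent mapping used throughout: every image coordinate is divided by LATENT_FACTOR = 8 (so a 768x512, batch-2 request in get_latents_from_seed above yields latents shaped (2, 4, 64, 96)). A worked example for get_tile_latents, assuming a 512 px tile at image coordinates (256, 0):

    LATENT_FACTOR = 8
    x, y, tile = 256, 0, 512

    t = tile // LATENT_FACTOR                      # 64 latent pixels per tile side
    x, y = x // LATENT_FACTOR, y // LATENT_FACTOR  # (32, 0)
    xt, yt = x + t, y + t                          # upper bounds of the latent tile: 96, 64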
@@ -14,6 +14,7 @@
     "settings": {
         "cSpell.words": [
             "astype",
+            "Autoencoder",
             "basicsr",
             "Civitai",
             "ckpt",