diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 7e2a7adc..c570179b 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -15,6 +15,8 @@ from ..models.meta import NetworkModel
 from ..params import DeviceParams
 from ..server import ServerContext
 from ..utils import run_gc
+from .patches.unet import UNetWrapper
+from .patches.vae import VAEWrapper
 from .pipelines.controlnet import OnnxStableDiffusionControlNetPipeline
 from .pipelines.lpw import OnnxStableDiffusionLongPromptWeightingPipeline
 from .pipelines.panorama import OnnxStableDiffusionPanoramaPipeline
@@ -397,92 +399,6 @@ def optimize_pipeline(
             logger.warning("error while enabling memory efficient attention: %s", e)


-# TODO: does this need to change for fp16 modes?
-timestep_dtype = np.float32
-
-
-class UNetWrapper(object):
-    prompt_embeds: Optional[List[np.ndarray]] = None
-    prompt_index: int = 0
-    server: ServerContext
-    wrapped: OnnxRuntimeModel
-
-    def __init__(
-        self,
-        server: ServerContext,
-        wrapped: OnnxRuntimeModel,
-    ):
-        self.server = server
-        self.wrapped = wrapped
-
-    def __call__(
-        self,
-        sample: np.ndarray = None,
-        timestep: np.ndarray = None,
-        encoder_hidden_states: np.ndarray = None,
-        **kwargs,
-    ):
-        global timestep_dtype
-        timestep_dtype = timestep.dtype
-
-        logger.trace(
-            "UNet parameter types: %s, %s, %s",
-            sample.dtype,
-            timestep.dtype,
-            encoder_hidden_states.dtype,
-        )
-
-        if self.prompt_embeds is not None:
-            step_index = self.prompt_index % len(self.prompt_embeds)
-            logger.trace("multiple prompt embeds found, using step: %s", step_index)
-            encoder_hidden_states = self.prompt_embeds[step_index]
-            self.prompt_index += 1
-
-        if sample.dtype != timestep.dtype:
-            logger.trace("converting UNet sample to timestep dtype")
-            sample = sample.astype(timestep.dtype)
-
-        if encoder_hidden_states.dtype != timestep.dtype:
-            logger.trace("converting UNet hidden states to timestep dtype")
-            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
-
-        return self.wrapped(
-            sample=sample,
-            timestep=timestep,
-            encoder_hidden_states=encoder_hidden_states,
-            **kwargs,
-        )
-
-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)
-
-    def set_prompts(self, prompt_embeds: List[np.ndarray]):
-        logger.debug(
-            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
-        )
-        self.prompt_embeds = prompt_embeds
-        self.prompt_index = 0
-
-
-class VAEWrapper(object):
-    def __init__(self, server, wrapped):
-        self.server = server
-        self.wrapped = wrapped
-
-    def __call__(self, latent_sample=None, **kwargs):
-        global timestep_dtype
-
-        logger.trace("VAE parameter types: %s", latent_sample.dtype)
-        if latent_sample.dtype != timestep_dtype:
-            logger.info("converting VAE sample dtype")
-            latent_sample = latent_sample.astype(timestep_dtype)
-
-        return self.wrapped(latent_sample=latent_sample, **kwargs)
-
-    def __getattr__(self, attr):
-        return getattr(self.wrapped, attr)
-
-
 def patch_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -496,7 +412,7 @@ def patch_pipeline(

     if hasattr(pipe, "vae_decoder"):
         original_vae = pipe.vae_decoder
-        pipe.vae_decoder = VAEWrapper(server, original_vae)
+        pipe.vae_decoder = VAEWrapper(server, original_vae, decoder=True)
     elif hasattr(pipe, "vae"):
         pass  # TODO: current wrapper does not work with upscaling VAE
     else:
diff --git a/api/onnx_web/diffusers/patches/unet.py b/api/onnx_web/diffusers/patches/unet.py
new file mode 100644
index 00000000..0678b91a
--- /dev/null
+++ b/api/onnx_web/diffusers/patches/unet.py
@@ -0,0 +1,72 @@
+from diffusers import OnnxRuntimeModel
+from logging import getLogger
+from typing import List, Optional
+
+from ...server import ServerContext
+from .vae import set_vae_dtype
+
+import numpy as np
+
+logger = getLogger(__name__)
+
+
+class UNetWrapper(object):
+    prompt_embeds: Optional[List[np.ndarray]] = None
+    prompt_index: int = 0
+    server: ServerContext
+    wrapped: OnnxRuntimeModel
+
+    def __init__(
+        self,
+        server: ServerContext,
+        wrapped: OnnxRuntimeModel,
+    ):
+        self.server = server
+        self.wrapped = wrapped
+
+    def __call__(
+        self,
+        sample: np.ndarray = None,
+        timestep: np.ndarray = None,
+        encoder_hidden_states: np.ndarray = None,
+        **kwargs,
+    ):
+        logger.trace(
+            "UNet parameter types: %s, %s, %s",
+            sample.dtype,
+            timestep.dtype,
+            encoder_hidden_states.dtype,
+        )
+        set_vae_dtype(timestep.dtype)
+
+        if self.prompt_embeds is not None:
+            step_index = self.prompt_index % len(self.prompt_embeds)
+            logger.trace("multiple prompt embeds found, using step: %s", step_index)
+            encoder_hidden_states = self.prompt_embeds[step_index]
+            self.prompt_index += 1
+
+        if sample.dtype != timestep.dtype:
+            logger.trace("converting UNet sample to timestep dtype")
+            sample = sample.astype(timestep.dtype)
+
+        if encoder_hidden_states.dtype != timestep.dtype:
+            logger.trace("converting UNet hidden states to timestep dtype")
+            encoder_hidden_states = encoder_hidden_states.astype(timestep.dtype)
+
+        return self.wrapped(
+            sample=sample,
+            timestep=timestep,
+            encoder_hidden_states=encoder_hidden_states,
+            **kwargs,
+        )
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+    def set_prompts(self, prompt_embeds: List[np.ndarray]):
+        logger.debug(
+            "setting prompt embeds for UNet: %s", [p.shape for p in prompt_embeds]
+        )
+        self.prompt_embeds = prompt_embeds
+        self.prompt_index = 0
+
diff --git a/api/onnx_web/diffusers/patches/vae.py b/api/onnx_web/diffusers/patches/vae.py
new file mode 100644
index 00000000..1e56e3da
--- /dev/null
+++ b/api/onnx_web/diffusers/patches/vae.py
@@ -0,0 +1,165 @@
+from typing import Union
+from diffusers.models.autoencoder_kl import AutoencoderKLOutput
+from diffusers.models.vae import DiagonalGaussianDistribution, DecoderOutput
+from logging import getLogger
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...server import ServerContext
+from diffusers import OnnxRuntimeModel
+
+logger = getLogger(__name__)
+
+LATENT_CHANNELS = 4
+SAMPLE_SIZE = 32
+
+# TODO: does this need to change for fp16 modes?
+timestep_dtype = np.float32
+
+
+def set_vae_dtype(dtype):
+    global timestep_dtype
+    timestep_dtype = dtype
+
+
+class VAEWrapper(object):
+    def __init__(self, server: ServerContext, wrapped: OnnxRuntimeModel, decoder: bool):
+        self.server = server
+        self.wrapped = wrapped
+        self.decoder = decoder
+
+        self.tile_sample_min_size = SAMPLE_SIZE
+        self.tile_latent_min_size = int(SAMPLE_SIZE / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor = 0.25
+
+        self.quant_conv = nn.Conv2d(2 * LATENT_CHANNELS, 2 * LATENT_CHANNELS, 1)
+        self.post_quant_conv = nn.Conv2d(LATENT_CHANNELS, LATENT_CHANNELS, 1)
+
+    def __call__(self, latent_sample=None, **kwargs):
+        global timestep_dtype
+
+        logger.trace("VAE %s parameter types: %s", ("decoder" if self.decoder else "encoder"), latent_sample.dtype)
+        if latent_sample.dtype != timestep_dtype:
+            logger.info("converting VAE sample dtype")
+            latent_sample = latent_sample.astype(timestep_dtype)
+
+        if self.decoder:
+            return self.tiled_decode(latent_sample, **kwargs)
+        else:
+            return self.tiled_encode(latent_sample, **kwargs)
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+    def blend_v(self, a, b, blend_extent):
+        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a, b, blend_extent):
+        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+    def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        The input tensor is split into overlapping tiles that are encoded in several steps, which keeps memory
+        use roughly constant regardless of image size. The result differs slightly from non-tiled encoding
+        because each tile is encoded separately; the overlapping tiles are blended together to hide the seams.
+        You may still see tile-sized changes in the output, but they should be much less noticeable.
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+        if isinstance(x, np.ndarray):
+            x = torch.from_numpy(x)
+
+        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into overlapping tiles of tile_sample_min_size and encode them separately.
+        rows = []
+        for i in range(0, x.shape[2], overlap_size):
+            row = []
+            for j in range(0, x.shape[3], overlap_size):
+                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+                tile = torch.from_numpy(self.wrapped(sample=tile.numpy())[0])  # run the wrapped ONNX encoder directly; self(tile) would recurse into tiled_encode
+                tile = self.quant_conv(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        moments = torch.cat(result_rows, dim=2)
+        posterior = DiagonalGaussianDistribution(moments)
+        posterior = posterior.numpy()
+
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of latents into images using a tiled decoder.
+
+        The latent tensor is split into overlapping tiles that are decoded in several steps, which keeps memory
+        use roughly constant regardless of image size. The result differs slightly from non-tiled decoding
+        because each tile is decoded separately; the overlapping tiles are blended together to hide the seams.
+        You may still see tile-sized changes in the output, but they should be much less noticeable.
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`]
+                instead of a plain tuple.
+        """
+        if isinstance(z, np.ndarray):
+            z = torch.from_numpy(z)
+
+        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+        row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping tiles of tile_latent_min_size and decode them separately.
+        # The tiles overlap to avoid visible seams between them.
+        rows = []
+        for i in range(0, z.shape[2], overlap_size):
+            row = []
+            for j in range(0, z.shape[3], overlap_size):
+                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+                tile = self.post_quant_conv(tile)
+                decoded = torch.from_numpy(self.wrapped(latent_sample=tile.detach().numpy())[0])  # run the wrapped ONNX decoder directly; self(tile) would recurse into tiled_decode
+                row.append(decoded)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent)
+                result_row.append(tile[:, :, :row_limit, :row_limit])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        dec = torch.cat(result_rows, dim=2)
+        dec = dec.numpy()
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
\ No newline at end of file
diff --git a/api/onnx_web/diffusers/pipelines/panorama.py b/api/onnx_web/diffusers/pipelines/panorama.py
index 2578021e..deceb624 100644
--- a/api/onnx_web/diffusers/pipelines/panorama.py
+++ b/api/onnx_web/diffusers/pipelines/panorama.py
@@ -283,7 +283,7 @@ class OnnxStableDiffusionPanoramaPipeline(DiffusionPipeline):
                 f" {negative_prompt_embeds.shape}."
             )

-    def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
+    def get_views(self, panorama_height, panorama_width, window_size=32, stride=8):
         # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
         panorama_height /= 8
         panorama_width /= 8
diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py
index 4f82b6cc..ea1df97c 100644
--- a/api/onnx_web/diffusers/utils.py
+++ b/api/onnx_web/diffusers/utils.py
@@ -11,15 +11,15 @@ from ..params import ImageParams, Size

 logger = getLogger(__name__)

-latent_channels = 4
-latent_factor = 8
+LATENT_CHANNELS = 4
+LATENT_FACTOR = 8
 MAX_TOKENS_PER_GROUP = 77

 CLIP_TOKEN = compile(r"\")
 INVERSION_TOKEN = compile(r"\")
 LORA_TOKEN = compile(r"\")
 INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
-ALTERNATIVE_RANGE = compile(r"\(([\w\|]+)\)")
+ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")


 def expand_interval_ranges(prompt: str) -> str:
@@ -253,9 +253,9 @@ def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray:
     """
     latents_shape = (
         batch,
-        latent_channels,
-        size.height // latent_factor,
-        size.width // latent_factor,
+        LATENT_CHANNELS,
+        size.height // LATENT_FACTOR,
+        size.width // LATENT_FACTOR,
     )
     rng = np.random.default_rng(seed)
     image_latents = rng.standard_normal(latents_shape).astype(np.float32)
@@ -266,9 +266,9 @@ def get_tile_latents(
     full_latents: np.ndarray, dims: Tuple[int, int, int]
 ) -> np.ndarray:
     x, y, tile = dims
-    t = tile // latent_factor
-    x = x // latent_factor
-    y = y // latent_factor
+    t = tile // LATENT_FACTOR
+    x = x // LATENT_FACTOR
+    y = y // LATENT_FACTOR
     xt = x + t
     yt = y + t

diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index b0c0abb7..65717d40 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -14,6 +14,7 @@
     "settings": {
      "cSpell.words": [
        "astype",
+        "Autoencoder",
        "basicsr",
        "Civitai",
        "ckpt",
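
Illustrative sketch (not part of the patch): the tiled decode in VAEWrapper above walks the latent in steps of overlap_size, blends each tile into its top and left neighbours over blend_extent rows/columns, then crops each tile to row_limit before reassembly. The standalone script below reproduces only that arithmetic with hypothetical sizes (an 8x8 tile and the same 0.25 overlap factor) and an identity "decode" in place of the per-tile ONNX VAE call; TILE and OVERLAP_FACTOR are names invented for the example, not values from the repository.

    import numpy as np

    TILE = 8                                          # stand-in for tile_latent_min_size
    OVERLAP_FACTOR = 0.25                             # same value as tile_overlap_factor in the patch
    overlap_size = int(TILE * (1 - OVERLAP_FACTOR))   # step between tile origins (6)
    blend_extent = int(TILE * OVERLAP_FACTOR)         # rows/columns blended into the neighbour (2)
    row_limit = TILE - blend_extent                   # rows/columns kept per tile after blending (6)


    def blend_v(a, b, extent):
        # same linear ramp as VAEWrapper.blend_v: the top rows of b fade in from the bottom rows of a
        for y in range(min(a.shape[2], b.shape[2], extent)):
            b[:, :, y, :] = a[:, :, -extent + y, :] * (1 - y / extent) + b[:, :, y, :] * (y / extent)
        return b


    def blend_h(a, b, extent):
        # same linear ramp as VAEWrapper.blend_h, applied along the width axis
        for x in range(min(a.shape[3], b.shape[3], extent)):
            b[:, :, :, x] = a[:, :, :, -extent + x] * (1 - x / extent) + b[:, :, :, x] * (x / extent)
        return b


    z = np.random.default_rng(0).standard_normal((1, 4, 16, 16))  # hypothetical NCHW latent

    # split into overlapping tiles; an identity "decode" keeps the example self-contained,
    # where the real wrapper runs the ONNX VAE decoder on each tile
    rows = []
    for i in range(0, z.shape[2], overlap_size):
        row = []
        for j in range(0, z.shape[3], overlap_size):
            row.append(z[:, :, i : i + TILE, j : j + TILE].copy())
        rows.append(row)

    # blend each tile into its top and left neighbours, crop to row_limit, and reassemble
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            if i > 0:
                tile = blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = blend_h(row[j - 1], tile, blend_extent)
            result_row.append(tile[:, :, :row_limit, :row_limit])
        result_rows.append(np.concatenate(result_row, axis=3))

    out = np.concatenate(result_rows, axis=2)
    print(out.shape)  # (1, 4, 16, 16): seams between tiles are blended over blend_extent rows/columns

Running it prints (1, 4, 16, 16), matching the input shape; swapping the identity decode for a real per-tile model call is exactly what the wrapper above does.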