From 6f283c5c0245c1e9df46aa8bdd50032404ccf998 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 19 Mar 2023 10:16:43 -0500
Subject: [PATCH] proxy nets that need fp16 conversion

---
 api/onnx_web/diffusers/load.py | 65 +++++++++++++++++++++++++++-------
 1 file changed, 53 insertions(+), 12 deletions(-)

diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 2408885c..189b66dd 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -2,6 +2,7 @@ from logging import getLogger
 from os import path
 from typing import Any, List, Optional, Tuple
 
+import torch
 import numpy as np
 from diffusers import (
     DDIMScheduler,
@@ -117,7 +118,11 @@ def load_pipeline(
     inversions: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
 ):
+    inversions = inversions or []
     loras = loras or []
+
+    torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
+    logger.debug("using Torch dtype %s for pipeline", torch_dtype)
     pipe_key = (
         pipeline.__name__,
         model,
@@ -144,6 +149,7 @@ def load_pipeline(
             provider=device.ort_provider(),
             sess_options=device.sess_options(),
             subfolder="scheduler",
+            torch_dtype=torch_dtype,
         )
 
         if device is not None and hasattr(scheduler, "to"):
@@ -170,6 +176,7 @@ def load_pipeline(
                 provider=device.ort_provider(),
                 sess_options=device.sess_options(),
                 subfolder="scheduler",
+                torch_dtype=torch_dtype,
             )
         }
 
@@ -186,6 +193,7 @@ def load_pipeline(
             tokenizer = CLIPTokenizer.from_pretrained(
                 model,
                 subfolder="tokenizer",
+                torch_dtype=torch_dtype,
             )
             text_encoder, tokenizer = blend_textual_inversions(
                 server,
@@ -275,6 +283,7 @@ def load_pipeline(
             sess_options=device.sess_options(),
             revision="onnx",
             safety_checker=None,
+            torch_dtype=torch_dtype,
             **components,
         )
 
@@ -338,6 +347,46 @@ def optimize_pipeline(
             logger.warning("error while enabling memory efficient attention: %s", e)
 
 
+timestep_dtype = None
+
+class UNetWrapper(object):
+    def __init__(self, wrapped):
+        self.wrapped = wrapped
+
+    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
+        global timestep_dtype
+        timestep_dtype = timestep.dtype
+
+        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
+        if sample.dtype != timestep.dtype:
+            logger.warning("converting UNet sample dtype")
+            sample = sample.astype(timestep.dtype)
+
+        return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+
+class VAEWrapper(object):
+    def __init__(self, wrapped):
+        self.wrapped = wrapped
+
+    def __call__(self, latent_sample=None):
+        global timestep_dtype
+
+        logger.info("VAE parameter types: %s", latent_sample.dtype)
+        if timestep_dtype is not None and latent_sample.dtype != timestep_dtype:
+            logger.warning("converting VAE sample dtype")
+            latent_sample = latent_sample.astype(timestep_dtype)
+
+        return self.wrapped(latent_sample=latent_sample)
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+
+
 def patch_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -346,16 +395,8 @@ def patch_pipeline(
     logger.debug("patching SD pipeline")
     pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
 
-    original_unet = pipe.unet.__call__
-    original_vae = pipe.vae_decoder.__call__
+    original_unet = pipe.unet
+    original_vae = pipe.vae_decoder
 
-    def unet_call(sample=None, timestep=None, encoder_hidden_states=None):
-        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        return original_unet(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
-
-    def vae_call(latent_sample=None):
-        logger.info("VAE parameter types: %s", latent_sample.dtype)
-        return original_vae(latent_sample=latent_sample)
-
-    pipe.unet.__call__ = unet_call
-    pipe.vae_decoder.__call__ = vae_call
+    pipe.unet = UNetWrapper(original_unet)
+    pipe.vae_decoder = VAEWrapper(original_vae)
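
---

Reviewer sketch (commentary, not part of the patch): a minimal, runnable
illustration of the proxy pattern the wrappers rely on. FakeUNet, its device
property, and the tensor shapes are hypothetical stand-ins for the real ONNX
session; the patch's logging and the shared timestep_dtype global are omitted
to keep the pattern visible. The two key points are that __call__ intercepts
the forward pass to reconcile dtypes, while __getattr__ transparently delegates
every other attribute lookup to the wrapped model.

import numpy as np


class FakeUNet:
    # hypothetical stand-in for the wrapped ONNX UNet session
    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
        return sample

    @property
    def device(self):
        return "cpu"


class UNetWrapper:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
        # fp16 models expect fp16 latents; convert when the dtypes disagree
        if sample.dtype != timestep.dtype:
            sample = sample.astype(timestep.dtype)
        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )

    def __getattr__(self, attr):
        # only invoked for names not found on the wrapper itself, so the
        # pipeline keeps seeing every attribute of the wrapped model
        return getattr(self.wrapped, attr)


unet = UNetWrapper(FakeUNet())
out = unet(
    sample=np.zeros((1, 4, 8, 8), dtype=np.float32),
    timestep=np.array([1], dtype=np.float16),
    encoder_hidden_states=np.zeros((1, 77, 768), dtype=np.float16),
)
print(out.dtype)    # float16: the fp32 sample was converted to match
print(unet.device)  # "cpu": delegated through __getattr__

This also explains why the commit replaces the earlier monkey-patching:
Python resolves obj(...) through type(obj).__call__, not through an instance
attribute, so the removed pipe.unet.__call__ = unet_call assignment never
took effect when the pipeline invoked pipe.unet(...). Swapping in a wrapper
object whose class defines __call__ does.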