
proxy nets that need fp16 conversion

Sean Sube 2023-03-19 10:16:43 -05:00
parent bbd4c0fd72
commit 6f283c5c02
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 53 additions and 12 deletions
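
The change wraps the ONNX UNet and VAE decoder in small proxy objects so that nets converted to fp16 still receive inputs of the dtype they expect, logging and casting when the pipeline hands them mismatched arrays. As a rough, self-contained sketch of that proxy pattern (the names below are illustrative only, not part of this commit):

    # Minimal sketch of the dtype-proxy idea, assuming numpy inputs as in the
    # ONNX pipeline; DtypeProxy and its arguments are hypothetical names.
    import numpy as np

    class DtypeProxy:
        def __init__(self, wrapped, target_dtype):
            self.wrapped = wrapped
            self.target_dtype = target_dtype

        def __call__(self, sample):
            # cast mismatched inputs before forwarding them to the wrapped net
            if sample.dtype != self.target_dtype:
                sample = sample.astype(self.target_dtype)
            return self.wrapped(sample)

        def __getattr__(self, attr):
            # forward everything else (config, providers, ...) to the wrapped net
            return getattr(self.wrapped, attr)

    # an fp16 "net" handed an fp32 sample still sees fp16 input
    net = DtypeProxy(lambda sample: sample.dtype, np.float16)
    assert net(np.zeros((1, 4), dtype=np.float32)) == np.float16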

@@ -2,6 +2,7 @@ from logging import getLogger
 from os import path
 from typing import Any, List, Optional, Tuple
+import torch
 import numpy as np
 from diffusers import (
     DDIMScheduler,
@@ -117,7 +118,11 @@ def load_pipeline(
     inversions: Optional[List[Tuple[str, float]]] = None,
     loras: Optional[List[Tuple[str, float]]] = None,
 ):
+    inversions = inversions or []
     loras = loras or []
+
+    torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
+    logger.debug("using Torch dtype %s for pipeline", torch_dtype)
     pipe_key = (
         pipeline.__name__,
         model,
@@ -144,6 +149,7 @@ def load_pipeline(
             provider=device.ort_provider(),
             sess_options=device.sess_options(),
             subfolder="scheduler",
+            torch_dtype=torch_dtype,
         )

         if device is not None and hasattr(scheduler, "to"):
@@ -170,6 +176,7 @@ def load_pipeline(
             provider=device.ort_provider(),
             sess_options=device.sess_options(),
             subfolder="scheduler",
+            torch_dtype=torch_dtype,
         )
     }
@@ -186,6 +193,7 @@ def load_pipeline(
         tokenizer = CLIPTokenizer.from_pretrained(
             model,
             subfolder="tokenizer",
+            torch_dtype=torch_dtype,
         )
         text_encoder, tokenizer = blend_textual_inversions(
             server,
@@ -275,6 +283,7 @@ def load_pipeline(
             sess_options=device.sess_options(),
             revision="onnx",
             safety_checker=None,
+            torch_dtype=torch_dtype,
             **components,
         )
@@ -338,6 +347,46 @@ def optimize_pipeline(
         logger.warning("error while enabling memory efficient attention: %s", e)

+
+timestep_dtype = None
+
+
+class UNetWrapper(object):
+    def __init__(self, wrapped):
+        self.wrapped = wrapped
+
+    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
+        global timestep_dtype
+        timestep_dtype = timestep.dtype
+
+        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
+        if sample.dtype != timestep.dtype:
+            logger.warning("converting UNet sample dtype")
+            sample = sample.astype(timestep.dtype)
+
+        return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+
+
+class VAEWrapper(object):
+    def __init__(self, wrapped):
+        self.wrapped = wrapped
+
+    def __call__(self, latent_sample=None):
+        global timestep_dtype
+
+        logger.info("VAE parameter types: %s", latent_sample.dtype)
+        if latent_sample.dtype != timestep_dtype:
+            logger.warning("converting VAE sample dtype")
+            latent_sample = latent_sample.astype(timestep_dtype)
+
+        return self.wrapped(latent_sample=latent_sample)
+
+    def __getattr__(self, attr):
+        return getattr(self.wrapped, attr)
+

 def patch_pipeline(
     server: ServerContext,
     pipe: StableDiffusionPipeline,
@@ -346,16 +395,8 @@ def patch_pipeline(
     logger.debug("patching SD pipeline")
     pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)

-    original_unet = pipe.unet.__call__
-    original_vae = pipe.vae_decoder.__call__
-
-    def unet_call(sample=None, timestep=None, encoder_hidden_states=None):
-        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        return original_unet(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
-
-    def vae_call(latent_sample=None):
-        logger.info("VAE parameter types: %s", latent_sample.dtype)
-        return original_vae(latent_sample=latent_sample)
-    pipe.unet.__call__ = unet_call
-    pipe.vae_decoder.__call__ = vae_call
+    original_unet = pipe.unet
+    original_vae = pipe.vae_decoder
+    pipe.unet = UNetWrapper(original_unet)
+    pipe.vae_decoder = VAEWrapper(original_vae)
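
One reading of this last hunk (not stated in the commit message): assigning a new function to pipe.unet.__call__ only sets an instance attribute, and Python resolves __call__ on the type when an object is called, so the removed patch would not intercept pipe.unet(...); a wrapper object does. A small, self-contained illustration with made-up names:

    class Net:
        def __call__(self, x):
            return "original"

    net = Net()
    net.__call__ = lambda x: "patched"  # instance attribute; net(x) ignores it
    assert net(1) == "original"

    class Wrapper:
        def __init__(self, wrapped):
            self.wrapped = wrapped

        def __call__(self, x):
            # intercept the call, then delegate to the wrapped object
            return "wrapped " + self.wrapped(x)

    assert Wrapper(Net())(1) == "wrapped original"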