proxy nets that need fp16 conversion

This commit is contained in:
Sean Sube 2023-03-19 10:16:43 -05:00
parent bbd4c0fd72
commit 6f283c5c02
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 53 additions and 12 deletions
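In brief, the commit selects a Torch dtype based on the server optimizations and wraps the ONNX UNet and VAE decoder in proxy objects that cast mismatched inputs at call time. A rough sketch of the dtype selection, assuming server.optimizations is a plain list of flag strings as the diff below suggests:

# Sketch only: mirrors the dtype selection added to load_pipeline() in the diff.
# "optimizations" stands in for server.optimizations here.
import torch

optimizations = ["torch-fp16"]
torch_dtype = torch.float16 if "torch-fp16" in optimizations else torch.float32
print(torch_dtype)  # torch.float16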


@@ -2,6 +2,7 @@ from logging import getLogger
from os import path
from typing import Any, List, Optional, Tuple

import torch
import numpy as np
from diffusers import (
    DDIMScheduler,
@@ -117,7 +118,11 @@ def load_pipeline(
    inversions: Optional[List[Tuple[str, float]]] = None,
    loras: Optional[List[Tuple[str, float]]] = None,
):
    inversions = inversions or []
    loras = loras or []

    torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
    logger.debug("using Torch dtype %s for pipeline", torch_dtype)

    pipe_key = (
        pipeline.__name__,
        model,
@@ -144,6 +149,7 @@ def load_pipeline(
                provider=device.ort_provider(),
                sess_options=device.sess_options(),
                subfolder="scheduler",
                torch_dtype=torch_dtype,
            )

            if device is not None and hasattr(scheduler, "to"):
@@ -170,6 +176,7 @@ def load_pipeline(
                provider=device.ort_provider(),
                sess_options=device.sess_options(),
                subfolder="scheduler",
                torch_dtype=torch_dtype,
            )
        }
@@ -186,6 +193,7 @@ def load_pipeline(
            tokenizer = CLIPTokenizer.from_pretrained(
                model,
                subfolder="tokenizer",
                torch_dtype=torch_dtype,
            )
            text_encoder, tokenizer = blend_textual_inversions(
                server,
@@ -275,6 +283,7 @@ def load_pipeline(
            sess_options=device.sess_options(),
            revision="onnx",
            safety_checker=None,
            torch_dtype=torch_dtype,
            **components,
        )
@@ -338,6 +347,46 @@ def optimize_pipeline(
        logger.warning("error while enabling memory efficient attention: %s", e)


timestep_dtype = None


class UNetWrapper(object):
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
        global timestep_dtype
        timestep_dtype = timestep.dtype
        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)

        if sample.dtype != timestep.dtype:
            logger.warning("converting UNet sample dtype")
            sample = sample.astype(timestep.dtype)

        return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)

    def __getattr__(self, attr):
        return getattr(self.wrapped, attr)


class VAEWrapper(object):
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, latent_sample=None):
        global timestep_dtype
        logger.info("VAE parameter types: %s", latent_sample.dtype)

        if latent_sample.dtype != timestep_dtype:
            logger.warning("converting VAE sample dtype")
            latent_sample = latent_sample.astype(timestep_dtype)

        return self.wrapped(latent_sample=latent_sample)

    def __getattr__(self, attr):
        return getattr(self.wrapped, attr)


def patch_pipeline(
    server: ServerContext,
    pipe: StableDiffusionPipeline,
@@ -346,16 +395,8 @@ def patch_pipeline(
    logger.debug("patching SD pipeline")
    pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)

    original_unet = pipe.unet.__call__
    original_vae = pipe.vae_decoder.__call__
    original_unet = pipe.unet
    original_vae = pipe.vae_decoder

    def unet_call(sample=None, timestep=None, encoder_hidden_states=None):
        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
        return original_unet(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)

    def vae_call(latent_sample=None):
        logger.info("VAE parameter types: %s", latent_sample.dtype)
        return original_vae(latent_sample=latent_sample)

    pipe.unet.__call__ = unet_call
    pipe.vae_decoder.__call__ = vae_call
    pipe.unet = UNetWrapper(original_unet)
    pipe.vae_decoder = VAEWrapper(original_vae)
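A minimal standalone sketch of the same dtype-coercion pattern, using numpy arrays and stand-in callables rather than the real ONNX sessions (all names below are illustrative and not part of the commit): the UNet proxy records the timestep dtype and casts its sample to match, and the VAE proxy later casts the latents to that recorded dtype before forwarding.

# Sketch only: demonstrates the proxy/casting behavior added in this commit.
import numpy as np

timestep_dtype = None


class UNetProxy:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, sample, timestep):
        # remember the timestep dtype so the VAE proxy can reuse it later
        global timestep_dtype
        timestep_dtype = timestep.dtype
        if sample.dtype != timestep.dtype:
            sample = sample.astype(timestep.dtype)
        return self.wrapped(sample=sample, timestep=timestep)

    def __getattr__(self, attr):
        # any other attribute access falls through to the wrapped model
        return getattr(self.wrapped, attr)


class VAEProxy:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __call__(self, latent_sample):
        # cast the latents to whatever dtype the UNet ran with
        if timestep_dtype is not None and latent_sample.dtype != timestep_dtype:
            latent_sample = latent_sample.astype(timestep_dtype)
        return self.wrapped(latent_sample=latent_sample)

    def __getattr__(self, attr):
        return getattr(self.wrapped, attr)


unet = UNetProxy(lambda sample, timestep: sample)
vae = VAEProxy(lambda latent_sample: latent_sample)

out = unet(np.ones((1, 4), dtype=np.float32), np.float16(1.0))
print(out.dtype)       # float16, cast to match the timestep
print(vae(out).dtype)  # float16, reusing the recorded dtype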