proxy nets that need fp16 conversion
This commit is contained in:
parent
bbd4c0fd72
commit
6f283c5c02
|
@ -2,6 +2,7 @@ from logging import getLogger
|
|||
from os import path
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
|
@ -117,7 +118,11 @@ def load_pipeline(
|
|||
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||
loras: Optional[List[Tuple[str, float]]] = None,
|
||||
):
|
||||
inversions = inversions or []
|
||||
loras = loras or []
|
||||
|
||||
torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
||||
pipe_key = (
|
||||
pipeline.__name__,
|
||||
model,
|
||||
|
@ -144,6 +149,7 @@ def load_pipeline(
|
|||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
|
||||
if device is not None and hasattr(scheduler, "to"):
|
||||
|
@ -170,6 +176,7 @@ def load_pipeline(
|
|||
provider=device.ort_provider(),
|
||||
sess_options=device.sess_options(),
|
||||
subfolder="scheduler",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -186,6 +193,7 @@ def load_pipeline(
|
|||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
model,
|
||||
subfolder="tokenizer",
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
text_encoder, tokenizer = blend_textual_inversions(
|
||||
server,
|
||||
|
@ -275,6 +283,7 @@ def load_pipeline(
|
|||
sess_options=device.sess_options(),
|
||||
revision="onnx",
|
||||
safety_checker=None,
|
||||
torch_dtype=torch_dtype,
|
||||
**components,
|
||||
)
|
||||
|
||||
|
@ -338,6 +347,46 @@ def optimize_pipeline(
|
|||
logger.warning("error while enabling memory efficient attention: %s", e)
|
||||
|
||||
|
||||
timestep_dtype = None
|
||||
|
||||
class UNetWrapper(object):
|
||||
def __init__(self, wrapped):
|
||||
self.wrapped = wrapped
|
||||
|
||||
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
|
||||
global timestep_dtype
|
||||
timestep_dtype = timestep.dtype
|
||||
|
||||
logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.warning("converting UNet sample dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
|
||||
return self.wrapped(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
||||
|
||||
class VAEWrapper(object):
|
||||
def __init__(self, wrapped):
|
||||
self.wrapped = wrapped
|
||||
|
||||
def __call__(self, latent_sample=None):
|
||||
global timestep_dtype
|
||||
|
||||
logger.info("VAE parameter types: %s", latent_sample.dtype)
|
||||
if latent_sample.dtype != timestep_dtype:
|
||||
logger.warning("converting VAE sample dtype")
|
||||
sample = sample.astype(timestep_dtype)
|
||||
|
||||
return self.wrapped(latent_sample=latent_sample)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.wrapped, attr)
|
||||
|
||||
|
||||
|
||||
def patch_pipeline(
|
||||
server: ServerContext,
|
||||
pipe: StableDiffusionPipeline,
|
||||
|
@ -346,16 +395,8 @@ def patch_pipeline(
|
|||
logger.debug("patching SD pipeline")
|
||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||
|
||||
original_unet = pipe.unet.__call__
|
||||
original_vae = pipe.vae_decoder.__call__
|
||||
original_unet = pipe.unet
|
||||
original_vae = pipe.vae_decoder
|
||||
|
||||
def unet_call(sample=None, timestep=None, encoder_hidden_states=None):
|
||||
logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
||||
return original_unet(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
def vae_call(latent_sample=None):
|
||||
logger.info("VAE parameter types: %s", latent_sample.dtype)
|
||||
return original_vae(latent_sample=latent_sample)
|
||||
|
||||
pipe.unet.__call__ = unet_call
|
||||
pipe.vae_decoder.__call__ = vae_call
|
||||
pipe.unet = UNetWrapper(original_unet)
|
||||
pipe.vae_decoder = VAEWrapper(original_vae)
|
||||
|
|
Loading…
Reference in New Issue