proxy nets that need fp16 conversion
This commit is contained in:
parent
bbd4c0fd72
commit
6f283c5c02
|
@ -2,6 +2,7 @@ from logging import getLogger
|
||||||
from os import path
|
from os import path
|
||||||
from typing import Any, List, Optional, Tuple
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers import (
|
from diffusers import (
|
||||||
DDIMScheduler,
|
DDIMScheduler,
|
||||||
|
@ -117,7 +118,11 @@ def load_pipeline(
|
||||||
inversions: Optional[List[Tuple[str, float]]] = None,
|
inversions: Optional[List[Tuple[str, float]]] = None,
|
||||||
loras: Optional[List[Tuple[str, float]]] = None,
|
loras: Optional[List[Tuple[str, float]]] = None,
|
||||||
):
|
):
|
||||||
|
inversions = inversions or []
|
||||||
loras = loras or []
|
loras = loras or []
|
||||||
|
|
||||||
|
torch_dtype = torch.float16 if "torch-fp16" in server.optimizations else torch.float32
|
||||||
|
logger.debug("using Torch dtype %s for pipeline", torch_dtype)
|
||||||
pipe_key = (
|
pipe_key = (
|
||||||
pipeline.__name__,
|
pipeline.__name__,
|
||||||
model,
|
model,
|
||||||
|
@ -144,6 +149,7 @@ def load_pipeline(
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
subfolder="scheduler",
|
subfolder="scheduler",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if device is not None and hasattr(scheduler, "to"):
|
if device is not None and hasattr(scheduler, "to"):
|
||||||
|
@ -170,6 +176,7 @@ def load_pipeline(
|
||||||
provider=device.ort_provider(),
|
provider=device.ort_provider(),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
subfolder="scheduler",
|
subfolder="scheduler",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -186,6 +193,7 @@ def load_pipeline(
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
model,
|
model,
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
)
|
)
|
||||||
text_encoder, tokenizer = blend_textual_inversions(
|
text_encoder, tokenizer = blend_textual_inversions(
|
||||||
server,
|
server,
|
||||||
|
@ -275,6 +283,7 @@ def load_pipeline(
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
revision="onnx",
|
revision="onnx",
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
**components,
|
**components,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -338,6 +347,46 @@ def optimize_pipeline(
|
||||||
logger.warning("error while enabling memory efficient attention: %s", e)
|
logger.warning("error while enabling memory efficient attention: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
# Dtype of the timestep most recently passed to UNetWrapper.__call__ (None until
# the first such call); VAEWrapper.__call__ compares its latents' dtype against it.
timestep_dtype = None
|
||||||
|
|
||||||
|
class UNetWrapper(object):
    """Callable proxy around a UNet that aligns the sample dtype with the timestep dtype.

    Calls are forwarded to the wrapped object after casting ``sample`` to the
    dtype of ``timestep`` when the two differ. Any attribute not defined on the
    proxy itself is read from the wrapped object.
    """

    def __init__(self, wrapped):
        """Keep a reference to the object that calls and attribute reads are forwarded to."""
        self.wrapped = wrapped

    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
        """Invoke the wrapped UNet, converting ``sample`` to ``timestep``'s dtype if needed.

        Args:
            sample: array-like exposing ``.dtype`` and ``.astype()``
                (presumably a numpy array -- confirm against the pipeline).
            timestep: array-like exposing ``.dtype``; its dtype is treated as
                the one the wrapped model expects.
            encoder_hidden_states: passed through unchanged.

        Returns:
            Whatever the wrapped callable returns.
        """
        global timestep_dtype
        # Publish the dtype in the module-level variable so code running after
        # the UNet step can match it.
        timestep_dtype = timestep.dtype

        logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
        if sample.dtype != timestep.dtype:
            logger.warning("converting UNet sample dtype")
            sample = sample.astype(timestep.dtype)

        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )

    def __getattr__(self, attr):
        """Delegate lookups that fail on the proxy to the wrapped object.

        Raises:
            AttributeError: if ``wrapped`` itself has not been set yet, or the
                wrapped object lacks ``attr``.
        """
        # __getattr__ only runs after normal lookup failed, so reaching here
        # with "wrapped" means the proxy is uninitialised (e.g. the blank
        # instance copy/pickle create). Without this guard the self.wrapped
        # read below would re-enter __getattr__ until RecursionError.
        if attr == "wrapped":
            raise AttributeError(attr)
        return getattr(self.wrapped, attr)
|
||||||
|
|
||||||
|
|
||||||
|
class VAEWrapper(object):
    """Callable proxy around a VAE decoder that casts latents to the recorded timestep dtype.

    Calls are forwarded to the wrapped object after converting ``latent_sample``
    to the module-level ``timestep_dtype`` when one has been recorded and the
    dtypes differ. Any attribute not defined on the proxy itself is read from
    the wrapped object.
    """

    def __init__(self, wrapped):
        """Keep a reference to the object that calls and attribute reads are forwarded to."""
        self.wrapped = wrapped

    def __call__(self, latent_sample=None):
        """Invoke the wrapped decoder, converting ``latent_sample`` if its dtype differs.

        Args:
            latent_sample: array-like exposing ``.dtype`` and ``.astype()``
                (presumably a numpy array -- confirm against the pipeline).

        Returns:
            Whatever the wrapped callable returns.
        """
        logger.info("VAE parameter types: %s", latent_sample.dtype)
        # timestep_dtype is None until something has recorded a dtype; in that
        # case there is no target to cast to, so the latents pass through as-is.
        if timestep_dtype is not None and latent_sample.dtype != timestep_dtype:
            logger.warning("converting VAE sample dtype")
            # Rebind latent_sample itself so the converted array is what
            # actually reaches the wrapped call.
            latent_sample = latent_sample.astype(timestep_dtype)

        return self.wrapped(latent_sample=latent_sample)

    def __getattr__(self, attr):
        """Delegate lookups that fail on the proxy to the wrapped object.

        Raises:
            AttributeError: if ``wrapped`` itself has not been set yet, or the
                wrapped object lacks ``attr``.
        """
        # __getattr__ only runs after normal lookup failed, so reaching here
        # with "wrapped" means the proxy is uninitialised (e.g. the blank
        # instance copy/pickle create). Without this guard the self.wrapped
        # read below would re-enter __getattr__ until RecursionError.
        if attr == "wrapped":
            raise AttributeError(attr)
        return getattr(self.wrapped, attr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def patch_pipeline(
|
def patch_pipeline(
|
||||||
server: ServerContext,
|
server: ServerContext,
|
||||||
pipe: StableDiffusionPipeline,
|
pipe: StableDiffusionPipeline,
|
||||||
|
@ -346,16 +395,8 @@ def patch_pipeline(
|
||||||
logger.debug("patching SD pipeline")
|
logger.debug("patching SD pipeline")
|
||||||
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
pipe._encode_prompt = expand_prompt.__get__(pipe, pipeline)
|
||||||
|
|
||||||
original_unet = pipe.unet.__call__
|
original_unet = pipe.unet
|
||||||
original_vae = pipe.vae_decoder.__call__
|
original_vae = pipe.vae_decoder
|
||||||
|
|
||||||
def unet_call(sample=None, timestep=None, encoder_hidden_states=None):
|
pipe.unet = UNetWrapper(original_unet)
|
||||||
logger.info("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
pipe.vae_decoder = VAEWrapper(original_vae)
|
||||||
return original_unet(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states)
|
|
||||||
|
|
||||||
def vae_call(latent_sample=None):
|
|
||||||
logger.info("VAE parameter types: %s", latent_sample.dtype)
|
|
||||||
return original_vae(latent_sample=latent_sample)
|
|
||||||
|
|
||||||
pipe.unet.__call__ = unet_call
|
|
||||||
pipe.vae_decoder.__call__ = vae_call
|
|
||||||
|
|
Loading…
Reference in New Issue