1
0
Fork 0

feat(api): add flag for ORT float16 optimizations

This commit is contained in:
Sean Sube 2023-03-19 11:59:35 -05:00
parent e4b59f0d9a
commit 1c631c28d3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 21 additions and 11 deletions

View File

@ -353,7 +353,8 @@ timestep_dtype = None
class UNetWrapper(object): class UNetWrapper(object):
def __init__(self, wrapped): def __init__(self, server, wrapped):
self.server = server
self.wrapped = wrapped self.wrapped = wrapped
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None): def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
@ -361,8 +362,12 @@ class UNetWrapper(object):
timestep_dtype = timestep.dtype timestep_dtype = timestep.dtype
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype) logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
if sample.dtype != timestep.dtype: if "onnx-fp16" in self.server.optimizations:
logger.info("converting UNet sample dtype") logger.info("converting UNet sample to ONNX fp16")
sample = sample.astype(np.float16)
encoder_hidden_states = encoder_hidden_states.astype(np.float16)
elif sample.dtype != timestep.dtype:
logger.info("converting UNet sample to timestep dtype")
sample = sample.astype(timestep.dtype) sample = sample.astype(timestep.dtype)
return self.wrapped( return self.wrapped(
@ -376,7 +381,8 @@ class UNetWrapper(object):
class VAEWrapper(object): class VAEWrapper(object):
def __init__(self, wrapped): def __init__(self, server, wrapped):
self.server = server
self.wrapped = wrapped self.wrapped = wrapped
def __call__(self, latent_sample=None): def __call__(self, latent_sample=None):
@ -404,5 +410,5 @@ def patch_pipeline(
original_unet = pipe.unet original_unet = pipe.unet
original_vae = pipe.vae_decoder original_vae = pipe.vae_decoder
pipe.unet = UNetWrapper(original_unet) pipe.unet = UNetWrapper(server, original_unet)
pipe.vae_decoder = VAEWrapper(original_vae) pipe.vae_decoder = VAEWrapper(server, original_vae)

View File

@ -99,8 +99,12 @@ Others:
- not available for ONNX pipelines (most of them) - not available for ONNX pipelines (most of them)
- https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches - https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
- `onnx-*` - `onnx-*`
- `onnx-low-memory` - `onnx-deterministic-compute`
- disable ONNX features that allocate more memory than is strictly required or keep memory after use - enable ONNX deterministic compute
- `onnx-fp16`
- force 16-bit floating point values when running pipelines
- use with https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/stable_diffusion#optimize-onnx-pipeline
and the `--float16` flag
- `onnx-graph-*` - `onnx-graph-*`
- `onnx-graph-disable` - `onnx-graph-disable`
- disable all ONNX graph optimizations - disable all ONNX graph optimizations
@ -108,11 +112,11 @@ Others:
- enable basic ONNX graph optimizations - enable basic ONNX graph optimizations
- `onnx-graph-all` - `onnx-graph-all`
- enable all ONNX graph optimizations - enable all ONNX graph optimizations
- `onnx-deterministic-compute` - `onnx-low-memory`
- enable ONNX deterministic compute - disable ONNX features that allocate more memory than is strictly required or keep memory after use
- `torch-*` - `torch-*`
- `torch-fp16` - `torch-fp16`
- use 16-bit floating point values when loading and converting pipelines - use 16-bit floating point values when converting and running pipelines
- applies during conversion as well - applies during conversion as well
- only available on CUDA platform - only available on CUDA platform