
feat(api): add flag for ORT float16 optimizations

Sean Sube 2023-03-19 11:59:35 -05:00
parent e4b59f0d9a
commit 1c631c28d3
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 21 additions and 11 deletions


@@ -353,7 +353,8 @@ timestep_dtype = None

 class UNetWrapper(object):
-    def __init__(self, wrapped):
+    def __init__(self, server, wrapped):
+        self.server = server
         self.wrapped = wrapped

     def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
@@ -361,8 +362,12 @@ class UNetWrapper(object):
         timestep_dtype = timestep.dtype
         logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        if sample.dtype != timestep.dtype:
-            logger.info("converting UNet sample dtype")
+        if "onnx-fp16" in self.server.optimizations:
+            logger.info("converting UNet sample to ONNX fp16")
+            sample = sample.astype(np.float16)
+            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
+        elif sample.dtype != timestep.dtype:
+            logger.info("converting UNet sample to timestep dtype")
             sample = sample.astype(timestep.dtype)

         return self.wrapped(
@@ -376,7 +381,8 @@ class UNetWrapper(object):

 class VAEWrapper(object):
-    def __init__(self, wrapped):
+    def __init__(self, server, wrapped):
+        self.server = server
         self.wrapped = wrapped

     def __call__(self, latent_sample=None):
@@ -404,5 +410,5 @@ def patch_pipeline(
     original_unet = pipe.unet
     original_vae = pipe.vae_decoder
-    pipe.unet = UNetWrapper(original_unet)
-    pipe.vae_decoder = VAEWrapper(original_vae)
+    pipe.unet = UNetWrapper(server, original_unet)
+    pipe.vae_decoder = VAEWrapper(server, original_vae)
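For reference outside the diff, here is a self-contained sketch of the wrapper pattern this commit extends, assuming a minimal server object that exposes an `optimizations` list; the logger calls are dropped and the class is an illustrative stand-in, not the full onnx-web implementation:

import numpy as np

class UNetWrapper:
    # illustrative stand-in for the patched wrapper in the diff above
    def __init__(self, server, wrapped):
        self.server = server
        self.wrapped = wrapped

    def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
        if "onnx-fp16" in self.server.optimizations:
            # fp16-optimized ONNX models expect float16 inputs
            sample = sample.astype(np.float16)
            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
        elif sample.dtype != timestep.dtype:
            # otherwise keep the sample dtype in sync with the timestep
            sample = sample.astype(timestep.dtype)

        return self.wrapped(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
        )

Because the cast happens at call time inside the wrapper, the same patched pipeline object can serve both fp32 and fp16 models, depending only on the server's optimization flags.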


@@ -99,8 +99,12 @@ Others:
     - not available for ONNX pipelines (most of them)
     - https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
 - `onnx-*`
-  - `onnx-low-memory`
-    - disable ONNX features that allocate more memory than is strictly required or keep memory after use
+  - `onnx-deterministic-compute`
+    - enable ONNX deterministic compute
+  - `onnx-fp16`
+    - force 16-bit floating point values when running pipelines
+    - use with https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/stable_diffusion#optimize-onnx-pipeline
+      and the `--float16` flag
   - `onnx-graph-*`
     - `onnx-graph-disable`
       - disable all ONNX graph optimizations
@@ -108,11 +112,11 @@ Others:
       - enable basic ONNX graph optimizations
     - `onnx-graph-all`
       - enable all ONNX graph optimizations
-  - `onnx-deterministic-compute`
-    - enable ONNX deterministic compute
+  - `onnx-low-memory`
+    - disable ONNX features that allocate more memory than is strictly required or keep memory after use
 - `torch-*`
   - `torch-fp16`
-    - use 16-bit floating point values when loading and converting pipelines
+    - use 16-bit floating point values when converting and running pipelines
+    - applies during conversion as well
     - only available on CUDA platform
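As a rough sketch of how these flags could reach the wrappers at runtime, assuming they arrive as a comma-separated string; the `ONNX_WEB_OPTIMIZATIONS` variable name and the `ServerContext` shape here are assumptions for illustration, not the project's confirmed configuration API:

import os

class ServerContext:
    # holds the parsed optimization flags checked by the pipeline wrappers
    def __init__(self, optimizations):
        self.optimizations = optimizations

    @classmethod
    def from_environ(cls):
        # hypothetical env var; the project's real config mechanism may differ
        raw = os.environ.get("ONNX_WEB_OPTIMIZATIONS", "")
        return cls([flag.strip() for flag in raw.split(",") if flag.strip()])

server = ServerContext.from_environ()
if "onnx-fp16" in server.optimizations:
    print("UNet and VAE inputs will be cast to float16")

Note that `onnx-fp16` only casts the pipeline inputs at call time; per the docs entry above, the model itself should first be converted with the ORT stable diffusion optimizer and its `--float16` flag.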