From 1c631c28d3571827442b083cbf1d1345b6cc19c7 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 19 Mar 2023 11:59:35 -0500
Subject: [PATCH] feat(api): add flag for ORT float16 optimizations

---
 api/onnx_web/diffusers/load.py | 18 ++++++++++++------
 docs/server-admin.md           | 14 +++++++++-----
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index a90e58b9..c96bb31c 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -353,7 +353,8 @@ timestep_dtype = None
 
 
 class UNetWrapper(object):
-    def __init__(self, wrapped):
+    def __init__(self, server, wrapped):
+        self.server = server
         self.wrapped = wrapped
 
     def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
@@ -361,8 +362,12 @@ class UNetWrapper(object):
         timestep_dtype = timestep.dtype
 
         logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
-        if sample.dtype != timestep.dtype:
-            logger.info("converting UNet sample dtype")
+        if "onnx-fp16" in self.server.optimizations:
+            logger.info("converting UNet sample to ONNX fp16")
+            sample = sample.astype(np.float16)
+            encoder_hidden_states = encoder_hidden_states.astype(np.float16)
+        elif sample.dtype != timestep.dtype:
+            logger.info("converting UNet sample to timestep dtype")
             sample = sample.astype(timestep.dtype)
 
         return self.wrapped(
@@ -376,7 +381,8 @@ class UNetWrapper(object):
 
 
 class VAEWrapper(object):
-    def __init__(self, wrapped):
+    def __init__(self, server, wrapped):
+        self.server = server
         self.wrapped = wrapped
 
     def __call__(self, latent_sample=None):
@@ -404,5 +410,5 @@ def patch_pipeline(
     original_unet = pipe.unet
     original_vae = pipe.vae_decoder
 
-    pipe.unet = UNetWrapper(original_unet)
-    pipe.vae_decoder = VAEWrapper(original_vae)
+    pipe.unet = UNetWrapper(server, original_unet)
+    pipe.vae_decoder = VAEWrapper(server, original_vae)
diff --git a/docs/server-admin.md b/docs/server-admin.md
index ff1a07ab..cc1a6434 100644
--- a/docs/server-admin.md
+++ b/docs/server-admin.md
@@ -99,8 +99,12 @@ Others:
     - not available for ONNX pipelines (most of them)
     - https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
 - `onnx-*`
-  - `onnx-low-memory`
-    - disable ONNX features that allocate more memory than is strictly required or keep memory after use
+  - `onnx-deterministic-compute`
+    - enable ONNX deterministic compute
+  - `onnx-fp16`
+    - force 16-bit floating point values when running pipelines
+    - use with https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/stable_diffusion#optimize-onnx-pipeline
+      and the `--float16` flag
   - `onnx-graph-*`
     - `onnx-graph-disable`
       - disable all ONNX graph optimizations
@@ -108,11 +112,11 @@ Others:
       - enable basic ONNX graph optimizations
     - `onnx-graph-all`
       - enable all ONNX graph optimizations
-  - `onnx-deterministic-compute`
-    - enable ONNX deterministic compute
+  - `onnx-low-memory`
+    - disable ONNX features that allocate more memory than is strictly required or keep memory after use
 - `torch-*`
   - `torch-fp16`
-    - use 16-bit floating point values when loading and converting pipelines
+    - use 16-bit floating point values when converting and running pipelines
     - applies during conversion as well
     - only available on CUDA platform
 
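
For reference, a minimal, runnable sketch of the runtime cast behavior this
patch introduces in UNetWrapper. The `ServerStub` class and `fake_unet`
function below are hypothetical stand-ins for onnx-web's server context and
ONNX UNet session, and are not part of the patch; only the `optimizations`
list that the wrapper actually consults is modeled:

    import numpy as np

    class ServerStub:
        # hypothetical stand-in for onnx-web's server context; only the
        # `optimizations` list read by the wrappers is modeled here
        def __init__(self, optimizations):
            self.optimizations = optimizations

    class UNetWrapper:
        # mirrors the patched wrapper: cast inputs to fp16 when the
        # `onnx-fp16` optimization is enabled, otherwise align the
        # sample dtype with the timestep dtype
        def __init__(self, server, wrapped):
            self.server = server
            self.wrapped = wrapped

        def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
            if "onnx-fp16" in self.server.optimizations:
                sample = sample.astype(np.float16)
                encoder_hidden_states = encoder_hidden_states.astype(np.float16)
            elif sample.dtype != timestep.dtype:
                sample = sample.astype(timestep.dtype)
            return self.wrapped(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
            )

    def fake_unet(sample=None, timestep=None, encoder_hidden_states=None):
        # hypothetical UNet that just reports the dtypes it received
        return sample.dtype, encoder_hidden_states.dtype

    server = ServerStub(["onnx-fp16"])
    unet = UNetWrapper(server, fake_unet)
    print(unet(
        sample=np.zeros((1, 4, 64, 64), dtype=np.float32),
        timestep=np.array([1], dtype=np.int64),
        encoder_hidden_states=np.zeros((1, 77, 768), dtype=np.float32),
    ))  # -> (dtype('float16'), dtype('float16'))

Dropping "onnx-fp16" from the list exercises the fallback branch instead,
which recasts the sample to the timestep dtype, as the wrapper did before
this patch.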