feat(api): add flag for ORT float16 optimizations
This commit is contained in:
parent
e4b59f0d9a
commit
1c631c28d3
|
@ -353,7 +353,8 @@ timestep_dtype = None
|
||||||
|
|
||||||
|
|
||||||
class UNetWrapper(object):
|
class UNetWrapper(object):
|
||||||
def __init__(self, wrapped):
|
def __init__(self, server, wrapped):
|
||||||
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
|
|
||||||
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
|
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
|
||||||
|
@ -361,8 +362,12 @@ class UNetWrapper(object):
|
||||||
timestep_dtype = timestep.dtype
|
timestep_dtype = timestep.dtype
|
||||||
|
|
||||||
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
||||||
if sample.dtype != timestep.dtype:
|
if "onnx-fp16" in self.server.optimizations:
|
||||||
logger.info("converting UNet sample dtype")
|
logger.info("converting UNet sample to ONNX fp16")
|
||||||
|
sample = sample.astype(np.float16)
|
||||||
|
encoder_hidden_states = encoder_hidden_states.astype(np.float16)
|
||||||
|
elif sample.dtype != timestep.dtype:
|
||||||
|
logger.info("converting UNet sample to timestep dtype")
|
||||||
sample = sample.astype(timestep.dtype)
|
sample = sample.astype(timestep.dtype)
|
||||||
|
|
||||||
return self.wrapped(
|
return self.wrapped(
|
||||||
|
@ -376,7 +381,8 @@ class UNetWrapper(object):
|
||||||
|
|
||||||
|
|
||||||
class VAEWrapper(object):
|
class VAEWrapper(object):
|
||||||
def __init__(self, wrapped):
|
def __init__(self, server, wrapped):
|
||||||
|
self.server = server
|
||||||
self.wrapped = wrapped
|
self.wrapped = wrapped
|
||||||
|
|
||||||
def __call__(self, latent_sample=None):
|
def __call__(self, latent_sample=None):
|
||||||
|
@ -404,5 +410,5 @@ def patch_pipeline(
|
||||||
original_unet = pipe.unet
|
original_unet = pipe.unet
|
||||||
original_vae = pipe.vae_decoder
|
original_vae = pipe.vae_decoder
|
||||||
|
|
||||||
pipe.unet = UNetWrapper(original_unet)
|
pipe.unet = UNetWrapper(server, original_unet)
|
||||||
pipe.vae_decoder = VAEWrapper(original_vae)
|
pipe.vae_decoder = VAEWrapper(server, original_vae)
|
||||||
|
|
|
@ -99,8 +99,12 @@ Others:
|
||||||
- not available for ONNX pipelines (most of them)
|
- not available for ONNX pipelines (most of them)
|
||||||
- https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
|
- https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
|
||||||
- `onnx-*`
|
- `onnx-*`
|
||||||
- `onnx-low-memory`
|
- `onnx-deterministic-compute`
|
||||||
- disable ONNX features that allocate more memory than is strictly required or keep memory after use
|
- enable ONNX deterministic compute
|
||||||
|
- `onnx-fp16`
|
||||||
|
- force 16-bit floating point values when running pipelines
|
||||||
|
- use with https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/stable_diffusion#optimize-onnx-pipeline
|
||||||
|
and the `--float16` flag
|
||||||
- `onnx-graph-*`
|
- `onnx-graph-*`
|
||||||
- `onnx-graph-disable`
|
- `onnx-graph-disable`
|
||||||
- disable all ONNX graph optimizations
|
- disable all ONNX graph optimizations
|
||||||
|
@ -108,11 +112,11 @@ Others:
|
||||||
- enable basic ONNX graph optimizations
|
- enable basic ONNX graph optimizations
|
||||||
- `onnx-graph-all`
|
- `onnx-graph-all`
|
||||||
- enable all ONNX graph optimizations
|
- enable all ONNX graph optimizations
|
||||||
- `onnx-deterministic-compute`
|
- `onnx-low-memory`
|
||||||
- enable ONNX deterministic compute
|
- disable ONNX features that allocate more memory than is strictly required or keep memory after use
|
||||||
- `torch-*`
|
- `torch-*`
|
||||||
- `torch-fp16`
|
- `torch-fp16`
|
||||||
- use 16-bit floating point values when loading and converting pipelines
|
- use 16-bit floating point values when converting and running pipelines
|
||||||
- applies during conversion as well
|
- applies during conversion as well
|
||||||
- only available on CUDA platform
|
- only available on CUDA platform
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue