feat(api): add flag for ORT float16 optimizations
This commit is contained in:
parent
e4b59f0d9a
commit
1c631c28d3
|
@ -353,7 +353,8 @@ timestep_dtype = None
|
|||
|
||||
|
||||
class UNetWrapper(object):
|
||||
def __init__(self, wrapped):
|
||||
def __init__(self, server, wrapped):
|
||||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
|
||||
def __call__(self, sample=None, timestep=None, encoder_hidden_states=None):
|
||||
|
@ -361,8 +362,12 @@ class UNetWrapper(object):
|
|||
timestep_dtype = timestep.dtype
|
||||
|
||||
logger.trace("UNet parameter types: %s, %s", sample.dtype, timestep.dtype)
|
||||
if sample.dtype != timestep.dtype:
|
||||
logger.info("converting UNet sample dtype")
|
||||
if "onnx-fp16" in self.server.optimizations:
|
||||
logger.info("converting UNet sample to ONNX fp16")
|
||||
sample = sample.astype(np.float16)
|
||||
encoder_hidden_states = encoder_hidden_states.astype(np.float16)
|
||||
elif sample.dtype != timestep.dtype:
|
||||
logger.info("converting UNet sample to timestep dtype")
|
||||
sample = sample.astype(timestep.dtype)
|
||||
|
||||
return self.wrapped(
|
||||
|
@ -376,7 +381,8 @@ class UNetWrapper(object):
|
|||
|
||||
|
||||
class VAEWrapper(object):
|
||||
def __init__(self, wrapped):
|
||||
def __init__(self, server, wrapped):
|
||||
self.server = server
|
||||
self.wrapped = wrapped
|
||||
|
||||
def __call__(self, latent_sample=None):
|
||||
|
@ -404,5 +410,5 @@ def patch_pipeline(
|
|||
original_unet = pipe.unet
|
||||
original_vae = pipe.vae_decoder
|
||||
|
||||
pipe.unet = UNetWrapper(original_unet)
|
||||
pipe.vae_decoder = VAEWrapper(original_vae)
|
||||
pipe.unet = UNetWrapper(server, original_unet)
|
||||
pipe.vae_decoder = VAEWrapper(server, original_vae)
|
||||
|
|
|
@ -99,8 +99,12 @@ Others:
|
|||
- not available for ONNX pipelines (most of them)
|
||||
- https://huggingface.co/docs/diffusers/optimization/fp16#sliced-vae-decode-for-larger-batches
|
||||
- `onnx-*`
|
||||
- `onnx-low-memory`
|
||||
- disable ONNX features that allocate more memory than is strictly required or keep memory after use
|
||||
- `onnx-deterministic-compute`
|
||||
- enable ONNX deterministic compute
|
||||
- `onnx-fp16`
|
||||
- force 16-bit floating point values when running pipelines
|
||||
- use with https://github.com/microsoft/onnxruntime/tree/main/onnxruntime/python/tools/transformers/models/stable_diffusion#optimize-onnx-pipeline
|
||||
and the `--float16` flag
|
||||
- `onnx-graph-*`
|
||||
- `onnx-graph-disable`
|
||||
- disable all ONNX graph optimizations
|
||||
|
@ -108,11 +112,11 @@ Others:
|
|||
- enable basic ONNX graph optimizations
|
||||
- `onnx-graph-all`
|
||||
- enable all ONNX graph optimizations
|
||||
- `onnx-deterministic-compute`
|
||||
- enable ONNX deterministic compute
|
||||
- `onnx-low-memory`
|
||||
- disable ONNX features that allocate more memory than is strictly required or keep memory after use
|
||||
- `torch-*`
|
||||
- `torch-fp16`
|
||||
- use 16-bit floating point values when loading and converting pipelines
|
||||
- use 16-bit floating point values when converting and running pipelines
|
||||
- applies during conversion as well
|
||||
- only available on CUDA platform
|
||||
|
||||
|
|
Loading…
Reference in New Issue