From eb3f1479f27bfe0035343f96818400129e3743a5 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Thu, 16 Nov 2023 21:45:50 -0600
Subject: [PATCH] fix(api): only use optimum's fp16 mode for SDXL export when torch fp16 is enabled

---
 api/onnx_web/convert/__main__.py               |  2 +-
 api/onnx_web/convert/diffusion/diffusion_xl.py |  2 +-
 api/onnx_web/diffusers/load.py                 | 14 +++++++-------
 api/onnx_web/server/context.py                 |  5 ++++-
 4 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index a969ca4c..5cbe7f07 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -599,7 +599,7 @@ def main(args=None) -> int:
     logger.info("CLI arguments: %s", args)
 
     server = ConversionContext.from_environ()
-    server.half = args.half or "onnx-fp16" in server.optimizations
+    server.half = args.half or server.has_optimization("onnx-fp16")
     server.opset = args.opset
     server.token = args.token
     logger.info(
diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py
index f6413cb7..f270f1d5 100644
--- a/api/onnx_web/convert/diffusion/diffusion_xl.py
+++ b/api/onnx_web/convert/diffusion/diffusion_xl.py
@@ -81,7 +81,7 @@ def convert_diffusion_diffusers_xl(
         output=dest_path,
         task="stable-diffusion-xl",
         device=device,
-        fp16=conversion.half,
+        fp16=conversion.has_optimization("torch-fp16"),  # optimum's fp16 mode only works on CUDA or ROCm
         framework="pt",
     )
 
diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 162a8699..894211fa 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -563,8 +563,8 @@ def optimize_pipeline(
     pipe: StableDiffusionPipeline,
 ) -> None:
     if (
-        "diffusers-attention-slicing" in server.optimizations
-        or "diffusers-attention-slicing-auto" in server.optimizations
+        server.has_optimization("diffusers-attention-slicing")
+        or server.has_optimization("diffusers-attention-slicing-auto")
     ):
         logger.debug("enabling auto attention slicing on SD pipeline")
         try:
@@ -572,28 +572,28 @@
         except Exception as e:
             logger.warning("error while enabling auto attention slicing: %s", e)
 
-    if "diffusers-attention-slicing-max" in server.optimizations:
+    if server.has_optimization("diffusers-attention-slicing-max"):
         logger.debug("enabling max attention slicing on SD pipeline")
         try:
             pipe.enable_attention_slicing(slice_size="max")
         except Exception as e:
             logger.warning("error while enabling max attention slicing: %s", e)
 
-    if "diffusers-vae-slicing" in server.optimizations:
+    if server.has_optimization("diffusers-vae-slicing"):
         logger.debug("enabling VAE slicing on SD pipeline")
         try:
             pipe.enable_vae_slicing()
         except Exception as e:
             logger.warning("error while enabling VAE slicing: %s", e)
 
-    if "diffusers-cpu-offload-sequential" in server.optimizations:
+    if server.has_optimization("diffusers-cpu-offload-sequential"):
         logger.debug("enabling sequential CPU offload on SD pipeline")
         try:
             pipe.enable_sequential_cpu_offload()
         except Exception as e:
             logger.warning("error while enabling sequential CPU offload: %s", e)
 
-    elif "diffusers-cpu-offload-model" in server.optimizations:
+    elif server.has_optimization("diffusers-cpu-offload-model"):
         # TODO: check for accelerate
         logger.debug("enabling model CPU offload on SD pipeline")
         try:
@@ -601,7 +601,7 @@
         except Exception as e:
             logger.warning("error while enabling model CPU offload: %s", e)
 
-    if "diffusers-memory-efficient-attention" in server.optimizations:
+    if server.has_optimization("diffusers-memory-efficient-attention"):
         # TODO: check for xformers
         logger.debug("enabling memory efficient attention for SD pipeline")
         try:
diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py
index 98bf6af5..d4118205 100644
--- a/api/onnx_web/server/context.py
+++ b/api/onnx_web/server/context.py
@@ -129,8 +129,11 @@ class ServerContext:
     def has_feature(self, flag: str) -> bool:
         return flag in self.feature_flags
 
+    def has_optimization(self, opt: str) -> bool:
+        return opt in self.optimizations
+
     def torch_dtype(self):
-        if "torch-fp16" in self.optimizations:
+        if self.has_optimization("torch-fp16"):
             return torch.float16
         else:
             return torch.float32