
fix(api): only use optimum's fp16 mode for SDXL export when torch fp16 is enabled

Sean Sube 2023-11-16 21:45:50 -06:00
parent b31227ecb3
commit eb3f1479f2
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 13 additions and 10 deletions
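The fix separates two fp16 paths that were previously conflated. A minimal sketch of the distinction, assuming the flag semantics shown in the hunks below (helper names other than has_optimization are hypothetical):

def wants_onnx_fp16(args, server) -> bool:
    # hypothetical helper: `half` keeps its old meaning, requested via
    # --half or the onnx-fp16 optimization, and drives ONNX-side fp16
    return args.half or server.has_optimization("onnx-fp16")

def wants_optimum_fp16(server) -> bool:
    # hypothetical helper: optimum's fp16 export runs the torch model in
    # half precision, which only works on CUDA or ROCm, so it is gated
    # on torch-fp16 instead of the broader `half` flag
    return server.has_optimization("torch-fp16")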

@@ -599,7 +599,7 @@ def main(args=None) -> int:
     logger.info("CLI arguments: %s", args)
     server = ConversionContext.from_environ()
-    server.half = args.half or "onnx-fp16" in server.optimizations
+    server.half = args.half or server.has_optimization("onnx-fp16")
     server.opset = args.opset
     server.token = args.token
     logger.info(

@@ -81,7 +81,7 @@ def convert_diffusion_diffusers_xl(
         output=dest_path,
         task="stable-diffusion-xl",
         device=device,
-        fp16=conversion.half,
+        fp16=conversion.has_optimization("torch-fp16"),  # optimum's fp16 mode only works on CUDA or ROCm
         framework="pt",
     )
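To see the behavior change in isolation, a hypothetical illustration (the Ctx class below is a stand-in for the real conversion context):

class Ctx:
    def __init__(self, optimizations):
        self.optimizations = optimizations
        # as in the first hunk: half is set by --half or onnx-fp16
        self.half = "onnx-fp16" in optimizations

    def has_optimization(self, opt: str) -> bool:
        return opt in self.optimizations

cpu = Ctx(["onnx-fp16"])   # CPU-only conversion
gpu = Ctx(["torch-fp16"])  # CUDA/ROCm conversion

assert cpu.half is True                             # old gate: would export fp16 on CPU and fail
assert cpu.has_optimization("torch-fp16") is False  # new gate: CPU stays fp32 in optimum
assert gpu.has_optimization("torch-fp16") is True   # new gate: GPU still gets fp16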

@@ -563,8 +563,8 @@ def optimize_pipeline(
     pipe: StableDiffusionPipeline,
 ) -> None:
     if (
-        "diffusers-attention-slicing" in server.optimizations
-        or "diffusers-attention-slicing-auto" in server.optimizations
+        server.has_optimization("diffusers-attention-slicing")
+        or server.has_optimization("diffusers-attention-slicing-auto")
     ):
         logger.debug("enabling auto attention slicing on SD pipeline")
         try:
@@ -572,28 +572,28 @@ def optimize_pipeline(
         except Exception as e:
             logger.warning("error while enabling auto attention slicing: %s", e)

-    if "diffusers-attention-slicing-max" in server.optimizations:
+    if server.has_optimization("diffusers-attention-slicing-max"):
         logger.debug("enabling max attention slicing on SD pipeline")
         try:
             pipe.enable_attention_slicing(slice_size="max")
         except Exception as e:
             logger.warning("error while enabling max attention slicing: %s", e)

-    if "diffusers-vae-slicing" in server.optimizations:
+    if server.has_optimization("diffusers-vae-slicing"):
         logger.debug("enabling VAE slicing on SD pipeline")
         try:
             pipe.enable_vae_slicing()
         except Exception as e:
             logger.warning("error while enabling VAE slicing: %s", e)

-    if "diffusers-cpu-offload-sequential" in server.optimizations:
+    if server.has_optimization("diffusers-cpu-offload-sequential"):
         logger.debug("enabling sequential CPU offload on SD pipeline")
         try:
             pipe.enable_sequential_cpu_offload()
         except Exception as e:
             logger.warning("error while enabling sequential CPU offload: %s", e)

-    elif "diffusers-cpu-offload-model" in server.optimizations:
+    elif server.has_optimization("diffusers-cpu-offload-model"):
         # TODO: check for accelerate
         logger.debug("enabling model CPU offload on SD pipeline")
         try:
@@ -601,7 +601,7 @@ def optimize_pipeline(
         except Exception as e:
             logger.warning("error while enabling model CPU offload: %s", e)

-    if "diffusers-memory-efficient-attention" in server.optimizations:
+    if server.has_optimization("diffusers-memory-efficient-attention"):
         # TODO: check for xformers
         logger.debug("enabling memory efficient attention for SD pipeline")
         try:

@@ -129,8 +129,11 @@ class ServerContext:
     def has_feature(self, flag: str) -> bool:
         return flag in self.feature_flags

+    def has_optimization(self, opt: str) -> bool:
+        return opt in self.optimizations
+
     def torch_dtype(self):
-        if "torch-fp16" in self.optimizations:
+        if self.has_optimization("torch-fp16"):
             return torch.float16
         else:
             return torch.float32
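For reference, a self-contained sketch of the new helper and how torch_dtype builds on it (the constructor shape is assumed for illustration; the real ServerContext takes many more settings):

import torch

class ServerContext:
    def __init__(self, optimizations=None):
        # assumed minimal constructor, just enough to exercise the helper
        self.optimizations = optimizations or []

    def has_optimization(self, opt: str) -> bool:
        # centralizes the membership checks scattered through the old code
        return opt in self.optimizations

    def torch_dtype(self):
        # half precision only when the torch-fp16 optimization is enabled
        if self.has_optimization("torch-fp16"):
            return torch.float16
        return torch.float32

assert ServerContext(["torch-fp16"]).torch_dtype() is torch.float16
assert ServerContext().torch_dtype() is torch.float32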