From 1351b2f3ff69e5866ea49622db612ec02a6b2576 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 6 Oct 2023 19:01:00 -0500 Subject: [PATCH] fix(api): allow SDXL VAE in any supported tensor format, ensure new SDXL models get hash file --- .../convert/diffusion/diffusion_xl.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 16081b53..54f37752 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -10,7 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16 from optimum.exporters.onnx import main_export from ...constants import ONNX_MODEL -from ..utils import ConversionContext +from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext logger = getLogger(__name__) @@ -42,13 +42,14 @@ def convert_diffusion_diffusers_xl( "converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path ) - if "hash" in model and not path.exists(model_hash): - logger.info("ONNX model does not have hash file, adding one") - with open(model_hash, "w") as f: - f.write(model["hash"]) - if path.exists(dest_path) and path.exists(model_index): logger.info("ONNX model already exists, skipping conversion") + + if "hash" in model and not path.exists(model_hash): + logger.info("ONNX model does not have hash file, adding one") + with open(model_hash, "w") as f: + f.write(model["hash"]) + return (False, dest_path) # safetensors -> diffusers directory with torch models @@ -63,7 +64,7 @@ def convert_diffusion_diffusers_xl( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if replace_vae.endswith(".safetensors"): + if check_ext(replace_vae, RESOLVE_FORMATS): pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: pipeline.vae = AutoencoderKL.from_pretrained(vae_path) @@ -80,6 +81,11 @@ def convert_diffusion_diffusers_xl( framework="pt", ) + if "hash" in model: + logger.debug("adding hash file to ONNX model") + with open(model_hash, "w") as f: + f.write(model["hash"]) + if conversion.half: unet_path = path.join(dest_path, "unet", ONNX_MODEL) infer_shapes_path(unet_path)