1
0
Fork 0

fix(api): allow SDXL VAE in any supported tensor format, ensure new SDXL models get hash file

This commit is contained in:
Sean Sube 2023-10-06 19:01:00 -05:00
parent 047e58c916
commit 1351b2f3ff
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 13 additions and 7 deletions

View File

@ -10,7 +10,7 @@ from onnxruntime.transformers.float16 import convert_float_to_float16
from optimum.exporters.onnx import main_export
from ...constants import ONNX_MODEL
from ..utils import ConversionContext
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
logger = getLogger(__name__)
@ -42,13 +42,14 @@ def convert_diffusion_diffusers_xl(
"converting Stable Diffusion XL model %s: %s -> %s/", name, source, dest_path
)
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping conversion")
if "hash" in model and not path.exists(model_hash):
logger.info("ONNX model does not have hash file, adding one")
with open(model_hash, "w") as f:
f.write(model["hash"])
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping conversion")
return (False, dest_path)
# safetensors -> diffusers directory with torch models
@ -63,7 +64,7 @@ def convert_diffusion_diffusers_xl(
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
if replace_vae.endswith(".safetensors"):
if check_ext(replace_vae, RESOLVE_FORMATS):
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
pipeline.vae = AutoencoderKL.from_pretrained(vae_path)
@ -80,6 +81,11 @@ def convert_diffusion_diffusers_xl(
framework="pt",
)
if "hash" in model:
logger.debug("adding hash file to ONNX model")
with open(model_hash, "w") as f:
f.write(model["hash"])
if conversion.half:
unet_path = path.join(dest_path, "unet", ONNX_MODEL)
infer_shapes_path(unet_path)