From 46d9fc0dd40a2b1c657dd7733b30ffa11101bb01 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 8 Dec 2023 18:49:18 -0600 Subject: [PATCH] fix(api): make sure diffusion models have a valid prefix --- api/onnx_web/convert/__main__.py | 20 +++++++++++++++++++ .../convert/diffusion/diffusion_xl.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index 5cbe7f07..fb302ae9 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -265,6 +265,20 @@ def fetch_model( return source, False +DIFFUSION_PREFIX = ["diffusion-", "diffusion/", "diffusion\\", "stable-diffusion-"] + + +def fix_diffusion_name(name: str): + if not any([name.startswith(prefix) for prefix in DIFFUSION_PREFIX]): + logger.warning( + "diffusion models must have names starting with diffusion- to be recognized by the server: %s does not match", + name, + ) + return f"diffusion-{name}" + + return name + + def convert_models(conversion: ConversionContext, args, models: Models): model_errors = [] @@ -351,6 +365,12 @@ def convert_models(conversion: ConversionContext, args, models: Models): if name in args.skip: logger.info("skipping model: %s", name) else: + # fix up entries with missing prefixes + name = fix_diffusion_name(name) + if name != model["name"]: + # update the model in-memory if the name changed + model["name"] = name + model_format = source_format(model) try: diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 7a03b700..18fa8493 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -64,7 +64,7 @@ def convert_diffusion_diffusers_xl( if replace_vae is not None: vae_path = path.join(conversion.model_path, replace_vae) - if check_ext(replace_vae, RESOLVE_FORMATS): + if check_ext(vae_path, RESOLVE_FORMATS): logger.debug("loading VAE from single tensor file: %s", vae_path) pipeline.vae = AutoencoderKL.from_single_file(vae_path) else: