diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index aea560d9..296f7351 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -275,7 +275,7 @@ def convert_diffusion_diffusers( """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - name = model.get("name") + name = str(model.get("name")).strip() source = model.get("source") # optional diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index d9319596..c09b9440 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -25,7 +25,7 @@ def convert_diffusion_diffusers_xl( """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ - name = model.get("name") + name = str(model.get("name")).strip() source = model.get("source") replace_vae = model.get("vae", None) diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 6bf1de2d..2dd37157 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union import torch from jsonschema import ValidationError, validate +from ..convert.utils import fix_diffusion_name from ..image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, @@ -189,6 +190,9 @@ def load_extras(server: ServerContext): for model in data[model_type]: model_name = model["name"] + if model_type == "diffusion": + model_name = fix_diffusion_name(model_name) + if "hash" in model: logger.debug( "collecting hash for model %s from %s",