fix(api): trim whitespace from model names because it breaks things (#376)
This commit is contained in:
parent
9c1fcd16fa
commit
4da4cd95a5
|
@ -275,7 +275,7 @@ def convert_diffusion_diffusers(
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
name = model.get("name")
|
name = str(model.get("name")).strip()
|
||||||
source = model.get("source")
|
source = model.get("source")
|
||||||
|
|
||||||
# optional
|
# optional
|
||||||
|
|
|
@ -25,7 +25,7 @@ def convert_diffusion_diffusers_xl(
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
name = model.get("name")
|
name = str(model.get("name")).strip()
|
||||||
source = model.get("source")
|
source = model.get("source")
|
||||||
replace_vae = model.get("vae", None)
|
replace_vae = model.get("vae", None)
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
import torch
|
import torch
|
||||||
from jsonschema import ValidationError, validate
|
from jsonschema import ValidationError, validate
|
||||||
|
|
||||||
|
from ..convert.utils import fix_diffusion_name
|
||||||
from ..image import ( # mask filters; noise sources
|
from ..image import ( # mask filters; noise sources
|
||||||
mask_filter_gaussian_multiply,
|
mask_filter_gaussian_multiply,
|
||||||
mask_filter_gaussian_screen,
|
mask_filter_gaussian_screen,
|
||||||
|
@ -189,6 +190,9 @@ def load_extras(server: ServerContext):
|
||||||
for model in data[model_type]:
|
for model in data[model_type]:
|
||||||
model_name = model["name"]
|
model_name = model["name"]
|
||||||
|
|
||||||
|
if model_type == "diffusion":
|
||||||
|
model_name = fix_diffusion_name(model_name)
|
||||||
|
|
||||||
if "hash" in model:
|
if "hash" in model:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"collecting hash for model %s from %s",
|
"collecting hash for model %s from %s",
|
||||||
|
|
Loading…
Reference in New Issue