1
0
Fork 0

fix(api): trim whitespace from model names because it breaks things (#376)

This commit is contained in:
Sean Sube 2023-12-10 13:59:47 -06:00
parent 9c1fcd16fa
commit 4da4cd95a5
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 6 additions and 2 deletions

View File

@ -275,7 +275,7 @@ def convert_diffusion_diffusers(
""" """
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
""" """
name = model.get("name") name = str(model.get("name")).strip()
source = model.get("source") source = model.get("source")
# optional # optional

View File

@ -25,7 +25,7 @@ def convert_diffusion_diffusers_xl(
""" """
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
""" """
name = model.get("name") name = str(model.get("name")).strip()
source = model.get("source") source = model.get("source")
replace_vae = model.get("vae", None) replace_vae = model.get("vae", None)

View File

@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
from jsonschema import ValidationError, validate from jsonschema import ValidationError, validate
from ..convert.utils import fix_diffusion_name
from ..image import ( # mask filters; noise sources from ..image import ( # mask filters; noise sources
mask_filter_gaussian_multiply, mask_filter_gaussian_multiply,
mask_filter_gaussian_screen, mask_filter_gaussian_screen,
@ -189,6 +190,9 @@ def load_extras(server: ServerContext):
for model in data[model_type]: for model in data[model_type]:
model_name = model["name"] model_name = model["name"]
if model_type == "diffusion":
model_name = fix_diffusion_name(model_name)
if "hash" in model: if "hash" in model:
logger.debug( logger.debug(
"collecting hash for model %s from %s", "collecting hash for model %s from %s",