diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index b9dc4394..ae55669c 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -24,7 +24,7 @@ REGION_TOKEN = compile( r"\]+)\>" ) RESEED_TOKEN = compile(r"\") -WILDCARD_TOKEN = compile(r"__([-/\\\w]+)__") +WILDCARD_TOKEN = compile(r"__([-/\\\w\.]+)__") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index 2dd37157..defd2ef2 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -2,7 +2,7 @@ from collections import defaultdict from functools import cmp_to_key from glob import glob from logging import getLogger -from os import path +from os import path, sep from typing import Any, Dict, List, Optional, Union import torch @@ -516,7 +516,19 @@ def load_wildcards(server: ServerContext) -> None: for file in structured_files: data = load_config(path.join(wildcard_path, file)) logger.debug("loading structured wildcards from %s: %s", file, data) + parse_wildcards(data, root_key=path.splitext(file)[0]) - for key, values in data.items(): - if isinstance(values, list): - wildcard_data[key].extend(values) + +def parse_wildcards(data: Any, root_key: Optional[str]=None) -> None: + global wildcard_data + + for key, values in data.items(): + if root_key is not None: + key=f"{root_key}{sep}{key}" + + if isinstance(values, dict): + parse_wildcards(values, root_key=key) + elif isinstance(values, list): + wildcard_data[key].extend(values) + else: + logger.warning("unable to parse key: %s from wildcard path: %s", key, root_key)