diff --git a/api/onnx_web/diffusers/utils.py b/api/onnx_web/diffusers/utils.py index 439828ea..02dcb35c 100644 --- a/api/onnx_web/diffusers/utils.py +++ b/api/onnx_web/diffusers/utils.py @@ -20,7 +20,7 @@ MAX_TOKENS_PER_GROUP = 77 CLIP_TOKEN = compile(r"\") INVERSION_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") LORA_TOKEN = compile(r"\]+):(-?[\.|\d]+)\>") -WILDCARD_TOKEN = compile(r"__([-\w]+)__") +WILDCARD_TOKEN = compile(r"__([-/\w]+)__") INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}") ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)") diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index e3795745..b3e325f0 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -255,17 +255,24 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"] def list_model_globs( - server: ServerContext, globs: List[str], base_path: Optional[str] = None + server: ServerContext, + globs: List[str], + base_path: Optional[str] = None, + recursive=False, + filename_only=True, ) -> List[str]: + if base_path is None: + base_path = server.model_path + models = [] for pattern in globs: - pattern_path = path.join(base_path or server.model_path, pattern) + pattern_path = path.join(base_path, pattern) logger.debug("loading models from %s", pattern_path) - for name in glob(pattern_path): + for name in glob(pattern_path, recursive=recursive): base = path.basename(name) (file, ext) = path.splitext(base) if ext not in IGNORE_EXTENSIONS: - models.append(file) + models.append(file if filename_only else path.relpath(name, base_path)) unique_models = list(set(models)) unique_models.sort() @@ -456,7 +463,11 @@ def load_wildcards(server: ServerContext) -> None: # simple wildcards wildcard_files = list_model_globs( - server, ["*.txt"], base_path=path.join(server.model_path, "wildcard") + server, + ["**/*.txt"], + base_path=path.join(server.model_path, "wildcard"), + filename_only=False, + recursive=True, ) for file in wildcard_files: @@ -465,6 +476,6 @@ def load_wildcards(server: ServerContext) -> None: lines = [line.strip() for line in lines if not line.startswith("#")] lines = [line for line in lines if len(line) > 0] logger.debug("loading wildcards from %s: %s", file, lines) - wildcard_data[file] = lines + wildcard_data[path.splitext(file)[0]] = lines # TODO: structured wildcards