feat(api): support wildcards in nested folders
This commit is contained in:
parent
1a084eba1c
commit
865b25e6d7
|
@ -20,7 +20,7 @@ MAX_TOKENS_PER_GROUP = 77
|
|||
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
||||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||
WILDCARD_TOKEN = compile(r"__([-\w]+)__")
|
||||
WILDCARD_TOKEN = compile(r"__([-/\w]+)__")
|
||||
|
||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||
|
|
|
@ -255,17 +255,24 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"]
|
|||
|
||||
|
||||
def list_model_globs(
|
||||
server: ServerContext, globs: List[str], base_path: Optional[str] = None
|
||||
server: ServerContext,
|
||||
globs: List[str],
|
||||
base_path: Optional[str] = None,
|
||||
recursive=False,
|
||||
filename_only=True,
|
||||
) -> List[str]:
|
||||
if base_path is None:
|
||||
base_path = server.model_path
|
||||
|
||||
models = []
|
||||
for pattern in globs:
|
||||
pattern_path = path.join(base_path or server.model_path, pattern)
|
||||
pattern_path = path.join(base_path, pattern)
|
||||
logger.debug("loading models from %s", pattern_path)
|
||||
for name in glob(pattern_path):
|
||||
for name in glob(pattern_path, recursive=recursive):
|
||||
base = path.basename(name)
|
||||
(file, ext) = path.splitext(base)
|
||||
if ext not in IGNORE_EXTENSIONS:
|
||||
models.append(file)
|
||||
models.append(file if filename_only else path.relpath(name, base_path))
|
||||
|
||||
unique_models = list(set(models))
|
||||
unique_models.sort()
|
||||
|
@ -456,7 +463,11 @@ def load_wildcards(server: ServerContext) -> None:
|
|||
|
||||
# simple wildcards
|
||||
wildcard_files = list_model_globs(
|
||||
server, ["*.txt"], base_path=path.join(server.model_path, "wildcard")
|
||||
server,
|
||||
["**/*.txt"],
|
||||
base_path=path.join(server.model_path, "wildcard"),
|
||||
filename_only=False,
|
||||
recursive=True,
|
||||
)
|
||||
|
||||
for file in wildcard_files:
|
||||
|
@ -465,6 +476,6 @@ def load_wildcards(server: ServerContext) -> None:
|
|||
lines = [line.strip() for line in lines if not line.startswith("#")]
|
||||
lines = [line for line in lines if len(line) > 0]
|
||||
logger.debug("loading wildcards from %s: %s", file, lines)
|
||||
wildcard_data[file] = lines
|
||||
wildcard_data[path.splitext(file)[0]] = lines
|
||||
|
||||
# TODO: structured wildcards
|
||||
|
|
Loading…
Reference in New Issue