feat(api): support wildcards in nested folders
This commit is contained in:
parent
1a084eba1c
commit
865b25e6d7
|
@ -20,7 +20,7 @@ MAX_TOKENS_PER_GROUP = 77
|
||||||
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
CLIP_TOKEN = compile(r"\<clip:([-\w]+):(\d+)\>")
|
||||||
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
INVERSION_TOKEN = compile(r"\<inversion:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
LORA_TOKEN = compile(r"\<lora:([^:\>]+):(-?[\.|\d]+)\>")
|
||||||
WILDCARD_TOKEN = compile(r"__([-\w]+)__")
|
WILDCARD_TOKEN = compile(r"__([-/\w]+)__")
|
||||||
|
|
||||||
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
INTERVAL_RANGE = compile(r"(\w+)-{(\d+),(\d+)(?:,(\d+))?}")
|
||||||
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
ALTERNATIVE_RANGE = compile(r"\(([^\)]+)\)")
|
||||||
|
|
|
@ -255,17 +255,24 @@ IGNORE_EXTENSIONS = [".crdownload", ".lock", ".tmp"]
|
||||||
|
|
||||||
|
|
||||||
def list_model_globs(
|
def list_model_globs(
|
||||||
server: ServerContext, globs: List[str], base_path: Optional[str] = None
|
server: ServerContext,
|
||||||
|
globs: List[str],
|
||||||
|
base_path: Optional[str] = None,
|
||||||
|
recursive=False,
|
||||||
|
filename_only=True,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
if base_path is None:
|
||||||
|
base_path = server.model_path
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
for pattern in globs:
|
for pattern in globs:
|
||||||
pattern_path = path.join(base_path or server.model_path, pattern)
|
pattern_path = path.join(base_path, pattern)
|
||||||
logger.debug("loading models from %s", pattern_path)
|
logger.debug("loading models from %s", pattern_path)
|
||||||
for name in glob(pattern_path):
|
for name in glob(pattern_path, recursive=recursive):
|
||||||
base = path.basename(name)
|
base = path.basename(name)
|
||||||
(file, ext) = path.splitext(base)
|
(file, ext) = path.splitext(base)
|
||||||
if ext not in IGNORE_EXTENSIONS:
|
if ext not in IGNORE_EXTENSIONS:
|
||||||
models.append(file)
|
models.append(file if filename_only else path.relpath(name, base_path))
|
||||||
|
|
||||||
unique_models = list(set(models))
|
unique_models = list(set(models))
|
||||||
unique_models.sort()
|
unique_models.sort()
|
||||||
|
@ -456,7 +463,11 @@ def load_wildcards(server: ServerContext) -> None:
|
||||||
|
|
||||||
# simple wildcards
|
# simple wildcards
|
||||||
wildcard_files = list_model_globs(
|
wildcard_files = list_model_globs(
|
||||||
server, ["*.txt"], base_path=path.join(server.model_path, "wildcard")
|
server,
|
||||||
|
["**/*.txt"],
|
||||||
|
base_path=path.join(server.model_path, "wildcard"),
|
||||||
|
filename_only=False,
|
||||||
|
recursive=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for file in wildcard_files:
|
for file in wildcard_files:
|
||||||
|
@ -465,6 +476,6 @@ def load_wildcards(server: ServerContext) -> None:
|
||||||
lines = [line.strip() for line in lines if not line.startswith("#")]
|
lines = [line.strip() for line in lines if not line.startswith("#")]
|
||||||
lines = [line for line in lines if len(line) > 0]
|
lines = [line for line in lines if len(line) > 0]
|
||||||
logger.debug("loading wildcards from %s: %s", file, lines)
|
logger.debug("loading wildcards from %s: %s", file, lines)
|
||||||
wildcard_data[file] = lines
|
wildcard_data[path.splitext(file)[0]] = lines
|
||||||
|
|
||||||
# TODO: structured wildcards
|
# TODO: structured wildcards
|
||||||
|
|
Loading…
Reference in New Issue