1
0
Fork 0

feat(api): load wildcards from markup files

This commit is contained in:
Sean Sube 2023-07-12 18:56:48 -05:00
parent 00fc584c99
commit 5df9aa2af7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 21 additions and 4 deletions

View File

@ -1,3 +1,4 @@
from collections import defaultdict
from functools import cmp_to_key
from glob import glob
from logging import getLogger
@ -90,7 +91,7 @@ correction_models: List[str] = []
diffusion_models: List[str] = []
network_models: List[NetworkModel] = []
upscaling_models: List[str] = []
wildcard_data: Dict[str, List[str]] = {}
wildcard_data: Dict[str, List[str]] = defaultdict(list)
# Loaded from extra_models
extra_hashes: Dict[str, str] = {}
@ -461,11 +462,13 @@ def load_platforms(server: ServerContext) -> None:
def load_wildcards(server: ServerContext) -> None:
global wildcard_data
wildcard_path = path.join(server.model_path, "wildcard")
# simple wildcards
wildcard_files = list_model_globs(
server,
["**/*.txt"],
base_path=path.join(server.model_path, "wildcard"),
base_path=wildcard_path,
filename_only=False,
recursive=True,
)
@ -476,6 +479,20 @@ def load_wildcards(server: ServerContext) -> None:
lines = [line.strip() for line in lines if not line.startswith("#")]
lines = [line for line in lines if len(line) > 0]
logger.debug("loading wildcards from %s: %s", file, lines)
wildcard_data[path.splitext(file)[0]] = lines
wildcard_data[path.splitext(file)[0]].extend(lines)
# TODO: structured wildcards
structured_files = list_model_globs(
server,
["**/*.json", "**/*.yaml"],
base_path=wildcard_path,
filename_only=False,
recursive=True,
)
for file in structured_files:
data = load_config(path.join(wildcard_path, file))
logger.debug("loading structured wildcards from %s: %s", file, data)
for key, values in data.items():
if isinstance(values, list):
wildcard_data[key].extend(values)