1
0
Fork 0

fix(api): filter out temporary files from model lists (#271)

This commit is contained in:
Sean Sube 2023-03-19 23:26:05 -05:00
parent b99c8c8bae
commit 19712262e6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 8 additions and 7 deletions

View File

@ -106,11 +106,6 @@ def get_config_value(key: str, subkey: str = "default", default=None):
return config_params.get(key, {}).get(subkey, default) return config_params.get(key, {}).get(subkey, default)
def get_model_name(model: str) -> str:
base = path.basename(model)
(file, _ext) = path.splitext(base)
return file
def load_extras(context: ServerContext): def load_extras(context: ServerContext):
""" """
@ -202,6 +197,9 @@ def load_extras(context: ServerContext):
extra_strings = strings extra_strings = strings
IGNORE_EXTENSIONS = ["crdownload", "lock", "tmp"]
def list_model_globs( def list_model_globs(
context: ServerContext, globs: List[str], base_path: Optional[str] = None context: ServerContext, globs: List[str], base_path: Optional[str] = None
) -> List[str]: ) -> List[str]:
@ -209,8 +207,11 @@ def list_model_globs(
for pattern in globs: for pattern in globs:
pattern_path = path.join(base_path or context.model_path, pattern) pattern_path = path.join(base_path or context.model_path, pattern)
logger.debug("loading models from %s", pattern_path) logger.debug("loading models from %s", pattern_path)
for name in glob(pattern_path):
models.extend([get_model_name(f) for f in glob(pattern_path)]) base = path.basename(name)
(file, ext) = path.splitext(base)
if ext not in IGNORE_EXTENSIONS:
models.append(file)
unique_models = list(set(models)) unique_models = list(set(models))
unique_models.sort() unique_models.sort()