fix(api): filter out temporary files from model lists (#271)
This commit is contained in:
parent
b99c8c8bae
commit
19712262e6
|
@ -106,11 +106,6 @@ def get_config_value(key: str, subkey: str = "default", default=None):
|
|||
return config_params.get(key, {}).get(subkey, default)
|
||||
|
||||
|
||||
def get_model_name(model: str) -> str:
|
||||
base = path.basename(model)
|
||||
(file, _ext) = path.splitext(base)
|
||||
return file
|
||||
|
||||
|
||||
def load_extras(context: ServerContext):
|
||||
"""
|
||||
|
@ -202,6 +197,9 @@ def load_extras(context: ServerContext):
|
|||
extra_strings = strings
|
||||
|
||||
|
||||
IGNORE_EXTENSIONS = ["crdownload", "lock", "tmp"]
|
||||
|
||||
|
||||
def list_model_globs(
|
||||
context: ServerContext, globs: List[str], base_path: Optional[str] = None
|
||||
) -> List[str]:
|
||||
|
@ -209,8 +207,11 @@ def list_model_globs(
|
|||
for pattern in globs:
|
||||
pattern_path = path.join(base_path or context.model_path, pattern)
|
||||
logger.debug("loading models from %s", pattern_path)
|
||||
|
||||
models.extend([get_model_name(f) for f in glob(pattern_path)])
|
||||
for name in glob(pattern_path):
|
||||
base = path.basename(name)
|
||||
(file, ext) = path.splitext(base)
|
||||
if ext not in IGNORE_EXTENSIONS:
|
||||
models.append(file)
|
||||
|
||||
unique_models = list(set(models))
|
||||
unique_models.sort()
|
||||
|
|
Loading…
Reference in New Issue