diff --git a/api/onnx_web/server/load.py b/api/onnx_web/server/load.py index a04ac3c2..cef4d5ac 100644 --- a/api/onnx_web/server/load.py +++ b/api/onnx_web/server/load.py @@ -106,11 +106,6 @@ def get_config_value(key: str, subkey: str = "default", default=None): return config_params.get(key, {}).get(subkey, default) -def get_model_name(model: str) -> str: - base = path.basename(model) - (file, _ext) = path.splitext(base) - return file - def load_extras(context: ServerContext): """ @@ -202,6 +197,9 @@ def load_extras(context: ServerContext): extra_strings = strings +IGNORE_EXTENSIONS = ["crdownload", "lock", "tmp"] + + def list_model_globs( context: ServerContext, globs: List[str], base_path: Optional[str] = None ) -> List[str]: @@ -209,8 +207,11 @@ def list_model_globs( for pattern in globs: pattern_path = path.join(base_path or context.model_path, pattern) logger.debug("loading models from %s", pattern_path) - - models.extend([get_model_name(f) for f in glob(pattern_path)]) + for name in glob(pattern_path): + base = path.basename(name) + (file, ext) = path.splitext(base) + if ext not in IGNORE_EXTENSIONS: + models.append(file) unique_models = list(set(models)) unique_models.sort()