diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 5c54b5a1..7f6d6ca8 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -203,17 +203,26 @@ def check_paths(context: ServerContext): makedirs(context.output_path) +def get_model_name(model: str) -> str: + base = path.basename(model) + (file, _ext) = path.splitext(base) + return file + + def load_models(context: ServerContext): global diffusion_models global correction_models global upscaling_models - diffusion_models = glob(path.join(context.model_path, 'diffusion-*')) - diffusion_models.extend(glob(path.join(context.model_path, 'stable-diffusion-*'))) - - correction_models = glob(path.join(context.model_path, 'correction-*')) - upscaling_models = glob(path.join(context.model_path, 'upscaling-*')) + diffusion_models = [get_model_name(f) for f in glob( + path.join(context.model_path, 'diffusion-*'))] + diffusion_models.extend([ + get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))]) + correction_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))] + upscaling_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))] def load_params(context: ServerContext):