diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index 9ab4494b..034fc3c6 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -106,19 +106,19 @@ class ServerContext: model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), - cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", "*"), + cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", default="*"), any_platform=get_boolean( environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM ), - block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS", ""), + block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS"), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), show_progress=get_boolean( environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS ), - optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS", ""), - extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS", ""), + optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS"), + extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS"), job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), memory_limit=memory_limit, admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), @@ -128,7 +128,7 @@ class ServerContext: worker_retries=int( environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) ), - feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS", ""), + feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS"), plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""), debug=get_boolean(environ, "ONNX_WEB_DEBUG", False), ) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 799e9dd5..9dac0d40 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -41,8 +41,8 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool: return val.lower() in ("1", "t", "true", "y", "yes") -def get_list(args: Any, key: str) -> List[str]: - return split_list(args.get(key, "")) +def get_list(args: Any, key: str, default = "") -> List[str]: + return split_list(args.get(key, default)) def get_and_clamp_float(