From 9e930a91d5c331acf78da0c6676e64727bf38128 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 24 Nov 2023 22:40:22 -0600 Subject: [PATCH] fix(api): load lists without empty items --- api/onnx_web/server/context.py | 17 ++++++++--------- api/onnx_web/utils.py | 9 +++++++++ 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/api/onnx_web/server/context.py b/api/onnx_web/server/context.py index e1af489d..9ab4494b 100644 --- a/api/onnx_web/server/context.py +++ b/api/onnx_web/server/context.py @@ -5,7 +5,7 @@ from typing import List, Optional import torch -from ..utils import get_boolean +from ..utils import get_boolean, get_list from .model_cache import ModelCache logger = getLogger(__name__) @@ -106,20 +106,19 @@ class ServerContext: model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), - # others - cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), + cors_origin=get_list(environ, "ONNX_WEB_CORS_ORIGIN", "*"), any_platform=get_boolean( environ, "ONNX_WEB_ANY_PLATFORM", DEFAULT_ANY_PLATFORM ), - block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), + block_platforms=get_list(environ, "ONNX_WEB_BLOCK_PLATFORMS", ""), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), - image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), + image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", DEFAULT_IMAGE_FORMAT), cache_limit=int(environ.get("ONNX_WEB_CACHE_MODELS", DEFAULT_CACHE_LIMIT)), show_progress=get_boolean( environ, "ONNX_WEB_SHOW_PROGRESS", DEFAULT_SHOW_PROGRESS ), - optimizations=environ.get("ONNX_WEB_OPTIMIZATIONS", "").split(","), - extra_models=environ.get("ONNX_WEB_EXTRA_MODELS", "").split(","), + optimizations=get_list(environ, "ONNX_WEB_OPTIMIZATIONS", ""), + extra_models=get_list(environ, "ONNX_WEB_EXTRA_MODELS", ""), job_limit=int(environ.get("ONNX_WEB_JOB_LIMIT", DEFAULT_JOB_LIMIT)), memory_limit=memory_limit, admin_token=environ.get("ONNX_WEB_ADMIN_TOKEN", None), @@ -129,8 +128,8 @@ class ServerContext: worker_retries=int( environ.get("ONNX_WEB_WORKER_RETRIES", DEFAULT_WORKER_RETRIES) ), - feature_flags=environ.get("ONNX_WEB_FEATURE_FLAGS", "").split(","), - plugins=environ.get("ONNX_WEB_PLUGINS", "").split(","), + feature_flags=get_list(environ, "ONNX_WEB_FEATURE_FLAGS", ""), + plugins=get_list(environ, "ONNX_WEB_PLUGINS", ""), debug=get_boolean(environ, "ONNX_WEB_DEBUG", False), ) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 43547c63..799e9dd5 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -18,6 +18,11 @@ logger = getLogger(__name__) SAFE_CHARS = "._-" +def split_list(val: str) -> List[str]: + parts = [part.strip() for part in val.split(",")] + return [part for part in parts if len(part.strip()) > 0] + + def base_join(base: str, tail: str) -> str: tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") return path.join(base, tail_path) @@ -36,6 +41,10 @@ def get_boolean(args: Any, key: str, default_value: bool) -> bool: return val.lower() in ("1", "t", "true", "y", "yes") +def get_list(args: Any, key: str) -> List[str]: + return split_list(args.get(key, "")) + + def get_and_clamp_float( args: Any, key: str, default_value: float, max_value: float, min_value=0.0 ) -> float: