From 9551e4a9b9c63817cb9967e3b376457c83795184 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 11 Feb 2023 15:53:27 -0600 Subject: [PATCH] feat(api): make any platform optional --- api/onnx_web/serve.py | 11 +++++++---- api/onnx_web/utils.py | 11 ++++++++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 6c187d5b..be4bb3b9 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -333,15 +333,18 @@ def load_params(context: ServerContext): config_params = yaml.safe_load(f) if "platform" in config_params and context.default_platform is not None: - logger.info("overriding default platform to %s", context.default_platform) + logger.info("Overriding default platform from environment: %s", context.default_platform) config_platform = config_params.get("platform", {}) config_platform["default"] = context.default_platform -def load_platforms(): +def load_platforms(context: ServerContext): global available_platforms - providers = ["any"].extend(get_available_providers()) + providers = [].extend(get_available_providers()) + + if context.any_platform: + providers.append("any") for potential in platform_providers: if ( @@ -391,7 +394,7 @@ context = ServerContext.from_environ() check_paths(context) load_models(context) load_params(context) -load_platforms() +load_platforms(context) app = Flask(__name__) CORS(app, origins=context.cors_origin) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 473d424c..ca649469 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -19,6 +19,7 @@ class ServerContext: params_path: str = ".", cors_origin: str = "*", num_workers: int = 1, + any_platform: bool = True, block_platforms: List[str] = [], default_platform: str = None, image_format: str = "png", @@ -29,6 +30,7 @@ class ServerContext: self.params_path = params_path self.cors_origin = cors_origin self.num_workers = num_workers + self.any_platform = any_platform self.block_platforms = block_platforms self.default_platform = default_platform self.image_format = image_format @@ -45,6 +47,7 @@ class ServerContext: # others cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)), + any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True), block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), @@ -57,7 +60,13 @@ def base_join(base: str, tail: str) -> str: def is_debug() -> bool: - return environ.get("DEBUG") is not None + return get_boolean(environ, "DEBUG", False) + + +def get_boolean( + args: Any, key: str, default_value: bool +) -> bool: + return (args.get(key, str(default_value)).lower() in ('1', 't', 'true', 'y', 'yes')) def get_and_clamp_float(