1
0
Fork 0

feat(api): make any platform optional

This commit is contained in:
Sean Sube 2023-02-11 15:53:27 -06:00
parent d7383d1101
commit 9551e4a9b9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 17 additions and 5 deletions

View File

@ -333,15 +333,18 @@ def load_params(context: ServerContext):
config_params = yaml.safe_load(f)
if "platform" in config_params and context.default_platform is not None:
logger.info("overriding default platform to %s", context.default_platform)
logger.info("Overriding default platform from environment: %s", context.default_platform)
config_platform = config_params.get("platform", {})
config_platform["default"] = context.default_platform
def load_platforms():
def load_platforms(context: ServerContext):
global available_platforms
providers = ["any"].extend(get_available_providers())
providers = [].extend(get_available_providers())
if context.any_platform:
providers.append("any")
for potential in platform_providers:
if (
@ -391,7 +394,7 @@ context = ServerContext.from_environ()
check_paths(context)
load_models(context)
load_params(context)
load_platforms()
load_platforms(context)
app = Flask(__name__)
CORS(app, origins=context.cors_origin)

View File

@ -19,6 +19,7 @@ class ServerContext:
params_path: str = ".",
cors_origin: str = "*",
num_workers: int = 1,
any_platform: bool = True,
block_platforms: List[str] = [],
default_platform: str = None,
image_format: str = "png",
@ -29,6 +30,7 @@ class ServerContext:
self.params_path = params_path
self.cors_origin = cors_origin
self.num_workers = num_workers
self.any_platform = any_platform
self.block_platforms = block_platforms
self.default_platform = default_platform
self.image_format = image_format
@ -45,6 +47,7 @@ class ServerContext:
# others
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
@ -57,7 +60,13 @@ def base_join(base: str, tail: str) -> str:
def is_debug() -> bool:
return environ.get("DEBUG") is not None
return get_boolean(environ, "DEBUG", False)
def get_boolean(
args: Any, key: str, default_value: bool
) -> bool:
return (args.get(key, str(default_value)).lower() in ('1', 't', 'true', 'y', 'yes'))
def get_and_clamp_float(