feat(api): make any platform optional
This commit is contained in:
parent
d7383d1101
commit
9551e4a9b9
|
@ -333,15 +333,18 @@ def load_params(context: ServerContext):
|
||||||
config_params = yaml.safe_load(f)
|
config_params = yaml.safe_load(f)
|
||||||
|
|
||||||
if "platform" in config_params and context.default_platform is not None:
|
if "platform" in config_params and context.default_platform is not None:
|
||||||
logger.info("overriding default platform to %s", context.default_platform)
|
logger.info("Overriding default platform from environment: %s", context.default_platform)
|
||||||
config_platform = config_params.get("platform", {})
|
config_platform = config_params.get("platform", {})
|
||||||
config_platform["default"] = context.default_platform
|
config_platform["default"] = context.default_platform
|
||||||
|
|
||||||
|
|
||||||
def load_platforms():
|
def load_platforms(context: ServerContext):
|
||||||
global available_platforms
|
global available_platforms
|
||||||
|
|
||||||
providers = ["any"].extend(get_available_providers())
|
providers = [].extend(get_available_providers())
|
||||||
|
|
||||||
|
if context.any_platform:
|
||||||
|
providers.append("any")
|
||||||
|
|
||||||
for potential in platform_providers:
|
for potential in platform_providers:
|
||||||
if (
|
if (
|
||||||
|
@ -391,7 +394,7 @@ context = ServerContext.from_environ()
|
||||||
check_paths(context)
|
check_paths(context)
|
||||||
load_models(context)
|
load_models(context)
|
||||||
load_params(context)
|
load_params(context)
|
||||||
load_platforms()
|
load_platforms(context)
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app, origins=context.cors_origin)
|
CORS(app, origins=context.cors_origin)
|
||||||
|
|
|
@ -19,6 +19,7 @@ class ServerContext:
|
||||||
params_path: str = ".",
|
params_path: str = ".",
|
||||||
cors_origin: str = "*",
|
cors_origin: str = "*",
|
||||||
num_workers: int = 1,
|
num_workers: int = 1,
|
||||||
|
any_platform: bool = True,
|
||||||
block_platforms: List[str] = [],
|
block_platforms: List[str] = [],
|
||||||
default_platform: str = None,
|
default_platform: str = None,
|
||||||
image_format: str = "png",
|
image_format: str = "png",
|
||||||
|
@ -29,6 +30,7 @@ class ServerContext:
|
||||||
self.params_path = params_path
|
self.params_path = params_path
|
||||||
self.cors_origin = cors_origin
|
self.cors_origin = cors_origin
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
|
self.any_platform = any_platform
|
||||||
self.block_platforms = block_platforms
|
self.block_platforms = block_platforms
|
||||||
self.default_platform = default_platform
|
self.default_platform = default_platform
|
||||||
self.image_format = image_format
|
self.image_format = image_format
|
||||||
|
@ -45,6 +47,7 @@ class ServerContext:
|
||||||
# others
|
# others
|
||||||
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
|
||||||
num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
|
num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
|
||||||
|
any_platform=get_boolean(environ, "ONNX_WEB_ANY_PLATFORM", True),
|
||||||
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","),
|
||||||
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
|
||||||
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
|
||||||
|
@ -57,7 +60,13 @@ def base_join(base: str, tail: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
def is_debug() -> bool:
|
def is_debug() -> bool:
|
||||||
return environ.get("DEBUG") is not None
|
return get_boolean(environ, "DEBUG", False)
|
||||||
|
|
||||||
|
|
||||||
|
def get_boolean(
|
||||||
|
args: Any, key: str, default_value: bool
|
||||||
|
) -> bool:
|
||||||
|
return (args.get(key, str(default_value)).lower() in ('1', 't', 'true', 'y', 'yes'))
|
||||||
|
|
||||||
|
|
||||||
def get_and_clamp_float(
|
def get_and_clamp_float(
|
||||||
|
|
Loading…
Reference in New Issue