1
0
Fork 0

feat(api): add a way for the server to disable certain platforms (#83)

This commit is contained in:
Sean Sube 2023-01-24 23:19:57 -06:00
parent 43787f085e
commit 67d51a96e3
2 changed files with 4 additions and 1 deletions

View File

@ -263,7 +263,7 @@ def load_platforms():
providers = get_available_providers()
available_platforms = [p for p in platform_providers if (
platform_providers[p] in providers)]
platform_providers[p] in providers and p not in context.block_platforms)]
print('available acceleration platforms: %s' % (available_platforms))

View File

@ -60,6 +60,7 @@ class ServerContext:
params_path: str = '.',
cors_origin: str = '*',
num_workers: int = 1,
block_platforms: List[str] = [],
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@ -67,6 +68,7 @@ class ServerContext:
self.params_path = params_path
self.cors_origin = cors_origin
self.num_workers = num_workers
self.block_platforms = block_platforms
@classmethod
def from_environ(cls):
@ -81,6 +83,7 @@ class ServerContext:
# others
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
block_platforms=environ.get('ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
)