feat(api): add a way for the server to disable certain platforms (#83)
This commit is contained in:
parent
43787f085e
commit
67d51a96e3
|
@ -263,7 +263,7 @@ def load_platforms():
|
|||
|
||||
providers = get_available_providers()
|
||||
available_platforms = [p for p in platform_providers if (
|
||||
platform_providers[p] in providers)]
|
||||
platform_providers[p] in providers and p not in context.block_platforms)]
|
||||
|
||||
print('available acceleration platforms: %s' % (available_platforms))
|
||||
|
||||
|
|
|
@ -60,6 +60,7 @@ class ServerContext:
|
|||
params_path: str = '.',
|
||||
cors_origin: str = '*',
|
||||
num_workers: int = 1,
|
||||
block_platforms: List[str] = [],
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -67,6 +68,7 @@ class ServerContext:
|
|||
self.params_path = params_path
|
||||
self.cors_origin = cors_origin
|
||||
self.num_workers = num_workers
|
||||
self.block_platforms = block_platforms
|
||||
|
||||
@classmethod
|
||||
def from_environ(cls):
|
||||
|
@ -81,6 +83,7 @@ class ServerContext:
|
|||
# others
|
||||
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
||||
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
||||
block_platforms=environ.get('ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue