1
0
Fork 0

feat(api): set default platform for each container (fixes #82)

This commit is contained in:
Sean Sube 2023-01-31 08:45:06 -06:00
parent 1ca7edb4a8
commit f4fc6271bc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 18 additions and 1 deletions

View File

@ -30,4 +30,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
COPY gui/ /onnx-web/gui/
ENV ONNX_WEB_DEFAULT_PLATFORM="cpu"
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]

View File

@ -34,4 +34,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
COPY gui/ /onnx-web/gui/
ENV ONNX_WEB_DEFAULT_PLATFORM="cuda"
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]

View File

@ -30,4 +30,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
COPY gui/ /onnx-web/gui/
ENV ONNX_WEB_DEFAULT_PLATFORM="directml"
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]

View File

@ -36,4 +36,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
COPY gui/ /onnx-web/gui/
ENV ONNX_WEB_DEFAULT_PLATFORM="rocm"
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]

View File

@ -293,6 +293,11 @@ def load_params(context: ServerContext):
with open(params_file, 'r') as f:
config_params = yaml.safe_load(f)
if 'platform' in config_params and context.default_platform is not None:
logger.info('overriding default platform to %s', context.default_platform)
config_platform = config_params.get('platform')
config_platform['default'] = context.default_platform
def load_platforms():
global available_platforms

View File

@ -25,6 +25,7 @@ class ServerContext:
cors_origin: str = '*',
num_workers: int = 1,
block_platforms: List[str] = [],
default_platform: str = None,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@ -33,6 +34,7 @@ class ServerContext:
self.cors_origin = cors_origin
self.num_workers = num_workers
self.block_platforms = block_platforms
self.default_platform = default_platform
@classmethod
def from_environ(cls):
@ -48,7 +50,9 @@ class ServerContext:
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
block_platforms=environ.get(
'ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
default_platform=environ.get(
'ONNX_WEB_DEFAULT_PLATFORM', None),
)