feat(api): set default platform for each container (fixes #82)
This commit is contained in:
parent
1ca7edb4a8
commit
f4fc6271bc
|
@ -30,4 +30,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
|||
|
||||
COPY gui/ /onnx-web/gui/
|
||||
|
||||
ENV ONNX_WEB_DEFAULT_PLATFORM="cpu"
|
||||
|
||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||
|
|
|
@ -34,4 +34,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
|||
|
||||
COPY gui/ /onnx-web/gui/
|
||||
|
||||
ENV ONNX_WEB_DEFAULT_PLATFORM="cuda"
|
||||
|
||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||
|
|
|
@ -30,4 +30,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
|||
|
||||
COPY gui/ /onnx-web/gui/
|
||||
|
||||
ENV ONNX_WEB_DEFAULT_PLATFORM="directml"
|
||||
|
||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||
|
|
|
@ -36,4 +36,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
|||
|
||||
COPY gui/ /onnx-web/gui/
|
||||
|
||||
ENV ONNX_WEB_DEFAULT_PLATFORM="rocm"
|
||||
|
||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||
|
|
|
@ -293,6 +293,11 @@ def load_params(context: ServerContext):
|
|||
with open(params_file, 'r') as f:
|
||||
config_params = yaml.safe_load(f)
|
||||
|
||||
if 'platform' in config_params and context.default_platform is not None:
|
||||
logger.info('overriding default platform to %s', context.default_platform)
|
||||
config_platform = config_params.get('platform')
|
||||
config_platform['default'] = context.default_platform
|
||||
|
||||
|
||||
def load_platforms():
|
||||
global available_platforms
|
||||
|
|
|
@ -25,6 +25,7 @@ class ServerContext:
|
|||
cors_origin: str = '*',
|
||||
num_workers: int = 1,
|
||||
block_platforms: List[str] = [],
|
||||
default_platform: str = None,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -33,6 +34,7 @@ class ServerContext:
|
|||
self.cors_origin = cors_origin
|
||||
self.num_workers = num_workers
|
||||
self.block_platforms = block_platforms
|
||||
self.default_platform = default_platform
|
||||
|
||||
@classmethod
|
||||
def from_environ(cls):
|
||||
|
@ -48,7 +50,9 @@ class ServerContext:
|
|||
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
||||
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
||||
block_platforms=environ.get(
|
||||
'ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
|
||||
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
|
||||
default_platform=environ.get(
|
||||
'ONNX_WEB_DEFAULT_PLATFORM', None),
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue