feat(api): set default platform for each container (fixes #82)
This commit is contained in:
parent
1ca7edb4a8
commit
f4fc6271bc
|
@ -30,4 +30,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
||||||
|
|
||||||
COPY gui/ /onnx-web/gui/
|
COPY gui/ /onnx-web/gui/
|
||||||
|
|
||||||
|
ENV ONNX_WEB_DEFAULT_PLATFORM="cpu"
|
||||||
|
|
||||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||||
|
|
|
@ -34,4 +34,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
||||||
|
|
||||||
COPY gui/ /onnx-web/gui/
|
COPY gui/ /onnx-web/gui/
|
||||||
|
|
||||||
|
ENV ONNX_WEB_DEFAULT_PLATFORM="cuda"
|
||||||
|
|
||||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||||
|
|
|
@ -30,4 +30,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
||||||
|
|
||||||
COPY gui/ /onnx-web/gui/
|
COPY gui/ /onnx-web/gui/
|
||||||
|
|
||||||
|
ENV ONNX_WEB_DEFAULT_PLATFORM="directml"
|
||||||
|
|
||||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||||
|
|
|
@ -36,4 +36,6 @@ ENV ONNX_WEB_BUNDLE_PATH="/onnx-web/gui"
|
||||||
|
|
||||||
COPY gui/ /onnx-web/gui/
|
COPY gui/ /onnx-web/gui/
|
||||||
|
|
||||||
|
ENV ONNX_WEB_DEFAULT_PLATFORM="rocm"
|
||||||
|
|
||||||
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
CMD [ "sh", "-c", "/onnx-web/launch.sh" ]
|
||||||
|
|
|
@ -293,6 +293,11 @@ def load_params(context: ServerContext):
|
||||||
with open(params_file, 'r') as f:
|
with open(params_file, 'r') as f:
|
||||||
config_params = yaml.safe_load(f)
|
config_params = yaml.safe_load(f)
|
||||||
|
|
||||||
|
if 'platform' in config_params and context.default_platform is not None:
|
||||||
|
logger.info('overriding default platform to %s', context.default_platform)
|
||||||
|
config_platform = config_params.get('platform')
|
||||||
|
config_platform['default'] = context.default_platform
|
||||||
|
|
||||||
|
|
||||||
def load_platforms():
|
def load_platforms():
|
||||||
global available_platforms
|
global available_platforms
|
||||||
|
|
|
@ -25,6 +25,7 @@ class ServerContext:
|
||||||
cors_origin: str = '*',
|
cors_origin: str = '*',
|
||||||
num_workers: int = 1,
|
num_workers: int = 1,
|
||||||
block_platforms: List[str] = [],
|
block_platforms: List[str] = [],
|
||||||
|
default_platform: str = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -33,6 +34,7 @@ class ServerContext:
|
||||||
self.cors_origin = cors_origin
|
self.cors_origin = cors_origin
|
||||||
self.num_workers = num_workers
|
self.num_workers = num_workers
|
||||||
self.block_platforms = block_platforms
|
self.block_platforms = block_platforms
|
||||||
|
self.default_platform = default_platform
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_environ(cls):
|
def from_environ(cls):
|
||||||
|
@ -48,7 +50,9 @@ class ServerContext:
|
||||||
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
|
||||||
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
|
||||||
block_platforms=environ.get(
|
block_platforms=environ.get(
|
||||||
'ONNX_WEB_BLOCK_PLATFORMS', '').split(',')
|
'ONNX_WEB_BLOCK_PLATFORMS', '').split(','),
|
||||||
|
default_platform=environ.get(
|
||||||
|
'ONNX_WEB_DEFAULT_PLATFORM', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue