diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py index bfeabb00..08c83caa 100644 --- a/api/onnx_web/convert/__main__.py +++ b/api/onnx_web/convert/__main__.py @@ -304,7 +304,10 @@ def main() -> int: args = parser.parse_args() logger.info("CLI arguments: %s", args) - ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token) + ctx = ConversionContext.from_environ() + ctx.half = args.half + ctx.opset = args.opset + ctx.token = args.token logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device) if ctx.half and ctx.training_device != "cuda": diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py index 0f63230d..f7c41cb5 100644 --- a/api/onnx_web/convert/utils.py +++ b/api/onnx_web/convert/utils.py @@ -24,11 +24,12 @@ class ConversionContext(ServerContext): def __init__( self, model_path: Optional[str] = None, - device: Optional[str] = None, cache_path: Optional[str] = None, + device: Optional[str] = None, half: Optional[bool] = False, opset: Optional[int] = None, token: Optional[str] = None, + **kwargs, ) -> None: super().__init__(self, model_path=model_path, cache_path=cache_path) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 8c2bb24e..59860639 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -48,7 +48,7 @@ class ServerContext: num_workers = int(environ.get("ONNX_WEB_NUM_WORKERS", 1)) cache_limit = int(environ.get("ONNX_WEB_CACHE_MODELS", num_workers + 2)) - return ServerContext( + return cls( bundle_path=environ.get( "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") ),