fix(api): create conversion context from environment vars
This commit is contained in:
parent
b4e66ef502
commit
15060e6c7d
|
@ -304,7 +304,10 @@ def main() -> int:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger.info("CLI arguments: %s", args)
|
logger.info("CLI arguments: %s", args)
|
||||||
|
|
||||||
ctx = ConversionContext(half=args.half, opset=args.opset, token=args.token)
|
ctx = ConversionContext.from_environ()
|
||||||
|
ctx.half = args.half
|
||||||
|
ctx.opset = args.opset
|
||||||
|
ctx.token = args.token
|
||||||
logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)
|
logger.info("converting models in %s using %s", ctx.model_path, ctx.training_device)
|
||||||
|
|
||||||
if ctx.half and ctx.training_device != "cuda":
|
if ctx.half and ctx.training_device != "cuda":
|
||||||
|
|
|
@ -24,11 +24,12 @@ class ConversionContext(ServerContext):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_path: Optional[str] = None,
|
model_path: Optional[str] = None,
|
||||||
device: Optional[str] = None,
|
|
||||||
cache_path: Optional[str] = None,
|
cache_path: Optional[str] = None,
|
||||||
|
device: Optional[str] = None,
|
||||||
half: Optional[bool] = False,
|
half: Optional[bool] = False,
|
||||||
opset: Optional[int] = None,
|
opset: Optional[int] = None,
|
||||||
token: Optional[str] = None,
|
token: Optional[str] = None,
|
||||||
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(self, model_path=model_path, cache_path=cache_path)
|
super().__init__(self, model_path=model_path, cache_path=cache_path)
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ class ServerContext:
|
||||||
num_workers = int(environ.get("ONNX_WEB_NUM_WORKERS", 1))
|
num_workers = int(environ.get("ONNX_WEB_NUM_WORKERS", 1))
|
||||||
cache_limit = int(environ.get("ONNX_WEB_CACHE_MODELS", num_workers + 2))
|
cache_limit = int(environ.get("ONNX_WEB_CACHE_MODELS", num_workers + 2))
|
||||||
|
|
||||||
return ServerContext(
|
return cls(
|
||||||
bundle_path=environ.get(
|
bundle_path=environ.get(
|
||||||
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
|
||||||
),
|
),
|
||||||
|
|
Loading…
Reference in New Issue