1
0
Fork 0
onnx-web/api/onnx_web/utils.py

125 lines
3.4 KiB
Python
Raw Normal View History

2023-02-05 13:53:26 +00:00
import gc
2023-01-28 23:09:19 +00:00
from logging import getLogger
from os import environ, path
2023-02-02 14:31:35 +00:00
from typing import Any, Dict, List, Optional, Union
2023-01-16 00:54:20 +00:00
import torch
2023-02-05 13:53:26 +00:00
from .params import SizeChart
2023-01-28 23:09:19 +00:00
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = getLogger(__name__)
class ServerContext:
    """Server-wide settings: filesystem paths, CORS, worker count, and platform options.

    Usually constructed with :meth:`from_environ`, which reads ``ONNX_WEB_*``
    environment variables.
    """

    def __init__(
        self,
        bundle_path: str = ".",
        model_path: str = ".",
        output_path: str = ".",
        params_path: str = ".",
        cors_origin: Union[str, List[str]] = "*",
        num_workers: int = 1,
        block_platforms: Optional[List[str]] = None,
        default_platform: Optional[str] = None,
        image_format: str = "png",
    ) -> None:
        """Store the given settings on the instance.

        Args:
            bundle_path: directory of the bundled GUI files.
            model_path: directory containing models.
            output_path: directory for generated outputs.
            params_path: directory containing parameter files.
            cors_origin: a single origin string or a list of origins
                (``from_environ`` always passes a list).
            num_workers: number of workers.
            block_platforms: platform names to block; ``None`` means none.
            default_platform: preferred platform name, or ``None``.
            image_format: output image format, e.g. ``"png"``.
        """
        self.bundle_path = bundle_path
        self.model_path = model_path
        self.output_path = output_path
        self.params_path = params_path
        self.cors_origin = cors_origin
        self.num_workers = num_workers
        # A literal [] default would be created once and shared by every
        # instance, so build a fresh list per instance instead.
        self.block_platforms = block_platforms if block_platforms is not None else []
        self.default_platform = default_platform
        self.image_format = image_format

    @classmethod
    def from_environ(cls) -> "ServerContext":
        """Build a context from ``ONNX_WEB_*`` environment variables.

        Comma-separated variables (``ONNX_WEB_CORS_ORIGIN``,
        ``ONNX_WEB_BLOCK_PLATFORMS``) are split into lists. For block
        platforms, entries are stripped and empty entries dropped, so an
        unset variable yields ``[]`` rather than ``[""]``.

        Raises:
            ValueError: if ``ONNX_WEB_NUM_WORKERS`` is not an integer.
        """
        raw_block = environ.get("ONNX_WEB_BLOCK_PLATFORMS", "")
        block_platforms = [
            name.strip() for name in raw_block.split(",") if name.strip()
        ]

        # cls (not a hard-coded class name) so subclasses get their own type.
        return cls(
            bundle_path=environ.get(
                "ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out")
            ),
            model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")),
            output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")),
            params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."),
            # others
            cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","),
            num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)),
            block_platforms=block_platforms,
            default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None),
            image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"),
        )
2023-02-02 14:31:35 +00:00
def base_join(base: str, tail: str) -> str:
    """Join ``tail`` onto ``base`` without letting ``tail`` climb above ``base``.

    ``tail`` is first anchored at the filesystem root so that ``normpath``
    collapses any leading ``..`` components against the root. The root is
    then stripped again, leaving a relative path that is appended to ``base``.
    """
    anchored_tail = path.normpath(path.join("/", tail))
    relative_tail = path.relpath(anchored_tail, "/")
    return path.join(base, relative_tail)
def is_debug() -> bool:
    """Return True when a ``DEBUG`` environment variable exists (any value, even empty)."""
    return "DEBUG" in environ
2023-02-05 13:53:26 +00:00
def get_and_clamp_float(
    args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float:
    """Read ``args[key]`` as a float and clamp it to ``[min_value, max_value]``.

    Args:
        args: any object with a dict-style ``.get(key, default)``.
        key: the key to read.
        default_value: used when ``key`` is absent or the parsed value is NaN.
        max_value: upper bound.
        min_value: lower bound.

    Returns:
        The parsed, clamped value.

    Raises:
        ValueError: if the stored value is a string ``float()`` cannot parse
            (``TypeError`` for non-numeric, non-string objects).
    """
    from math import isnan

    value = float(args.get(key, default_value))
    # NaN compares false against both bounds, so min/max would pass it
    # through unclamped; fall back to the default instead.
    if isnan(value):
        value = float(default_value)

    return min(max(value, min_value), max_value)
2023-02-05 13:53:26 +00:00
def get_and_clamp_int(
    args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int:
    """Read ``args[key]`` as an int, bounded below by ``min_value`` and above by ``max_value``.

    ``default_value`` is used when ``key`` is missing. ``int()`` raises
    ``ValueError`` for strings that are not integers (e.g. ``"1.5"``).
    """
    parsed = int(args.get(key, default_value))
    at_least_min = max(parsed, min_value)
    return min(at_least_min, max_value)
def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
    """Return ``args[key]`` if it is one of ``values``, otherwise a fallback.

    Args:
        args: any object with a dict-style ``.get(key, default)``.
        key: the key to read.
        values: the allowed selections, in priority order.

    Returns:
        The selected value when it appears in ``values``. Otherwise the first
        entry of ``values``, or ``None`` if ``values`` is empty. In the
        fallback case, a warning is logged, including when ``key`` is absent.
    """
    selected = args.get(key, None)
    if selected in values:
        return selected

    # Logger.warn is a deprecated alias; use Logger.warning.
    logger.warning("invalid selection: %s", selected)
    if len(values) > 0:
        return values[0]

    return None
2023-01-17 02:10:52 +00:00
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any:
    """Look up ``args[key]`` in ``values``, falling back to ``values[default]``.

    ``default`` is also used as the lookup key when ``key`` is missing from
    ``args`` or its value is not a key of ``values``. Raises ``KeyError`` if
    ``default`` itself is not a key of ``values``.
    """
    choice = args.get(key, default)
    lookup_key = choice if choice in values else default
    return values[lookup_key]
def get_not_empty(args: Any, key: str, default: Any) -> Any:
    """Return ``args[key]``, or ``default`` when it is missing, ``None``, or zero-length.

    The stored value must support ``len()`` unless it is ``None``.
    """
    candidate = args.get(key, default)
    is_blank = candidate is None or len(candidate) == 0
    return default if is_blank else candidate
2023-01-29 05:06:25 +00:00
def get_size(val: Union[int, str, None]) -> Union[SizeChart, int]:
    """Parse a size parameter into a ``SizeChart`` entry or a plain integer.

    Args:
        val: ``None``, an int, the ``.name`` of a ``SizeChart`` entry, or a
            string of digits.

    Returns:
        ``SizeChart.auto`` for ``None``; the int unchanged; the ``SizeChart``
        entry whose ``.name`` equals the string; otherwise ``int(val)``.

    Raises:
        ValueError: for unsupported types, or for a string that is neither a
            ``SizeChart`` name nor parseable by ``int()``.
    """
    if val is None:
        return SizeChart.auto

    # Exact type checks (not isinstance) so that bool, an int subclass,
    # is rejected rather than treated as a size.
    if type(val) is int:
        return val

    if type(val) is str:
        for size in SizeChart:
            if val == size.name:
                return size
        return int(val)

    # ValueError subclasses Exception, so existing `except Exception` callers still work.
    raise ValueError("invalid size")
2023-01-29 05:06:25 +00:00
def run_gc():
    """Run a full Python garbage collection, then call ``torch.cuda.empty_cache()``.

    Logs a debug message first. Intended to free memory between jobs;
    returns nothing.
    """
    logger.debug("running garbage collection")

    # Collect Python objects first so that tensors they referenced are
    # released before asking torch to return cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()