1
0
Fork 0
onnx-web/api/onnx_web/utils.py

254 lines
6.5 KiB
Python
Raw Normal View History

2023-02-05 13:53:26 +00:00
import gc
import importlib
import json
2023-04-24 23:10:12 +00:00
import threading
from hashlib import sha256
from json import JSONDecodeError
2023-01-28 23:09:19 +00:00
from logging import getLogger
from os import environ, path
from platform import system
from struct import pack
2023-07-04 17:09:46 +00:00
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
2023-01-16 00:54:20 +00:00
import torch
from yaml import safe_load
from .params import DeviceParams, Param, SizeChart
2023-01-28 23:09:19 +00:00
logger = getLogger(__name__)

# Characters that sanitize_name keeps in addition to alphanumerics.
SAFE_CHARS = "._-"


def split_list(val: str) -> List[str]:
    """Split a comma-separated string into its non-empty, whitespace-stripped items.

    >>> split_list(" a, ,b ,")
    ['a', 'b']
    """
    stripped = (item.strip() for item in val.split(","))
    return [item for item in stripped if item]
2023-02-02 14:31:35 +00:00
def base_join(base: str, tail: str) -> str:
    """Join ``tail`` onto ``base`` so the result cannot escape ``base``.

    ``tail`` is first anchored at the filesystem root and normalized, so
    leading ``..`` segments and absolute paths collapse into a path that is
    relative to ``base``.
    """
    rooted = path.normpath(path.join("/", tail))
    relative = path.relpath(rooted, "/")
    return path.join(base, relative)
def is_debug() -> bool:
    """Return True when the ``DEBUG`` environment variable parses as true.

    Parsing is delegated to ``get_boolean``; an unset variable counts as false.
    """
    debug_flag = get_boolean(environ, "DEBUG", False)
    return debug_flag
2023-02-11 22:50:57 +00:00
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
    """Read ``key`` from ``args`` and interpret it as a boolean.

    Accepted forms:
      * ``bool`` values are returned unchanged;
      * numbers are true when non-zero (e.g. ``1`` from a parsed JSON/YAML file);
      * anything else is converted to a string and is true when, ignoring case
        and surrounding whitespace, it is one of ``1``, ``t``, ``true``, ``y``
        or ``yes``.

    A missing key, or a key explicitly set to ``None``, yields ``default_value``.

    Args:
        args: any mapping with a ``get(key, default)`` method (dict, environ, ...).
        key: the key to read.
        default_value: the result when the key is missing or ``None``.
    """
    val = args.get(key, str(default_value))
    if val is None:
        # an explicit null previously crashed on .lower(); treat it as unset
        return default_value

    # bool is a subclass of int, so it must be checked before the numeric case
    if isinstance(val, bool):
        return val

    if isinstance(val, (int, float)):
        return val != 0

    return str(val).strip().lower() in ("1", "t", "true", "y", "yes")
2023-11-25 13:50:36 +00:00
def get_list(args: Any, key: str, default="") -> List[str]:
    """Read ``key`` from ``args`` as a list of non-empty, stripped strings.

    String values are split on commas with ``split_list``. Values that are
    already a list or tuple (e.g. from a parsed JSON/YAML document) are
    stripped item by item instead of crashing on a missing ``split`` method.
    A key explicitly set to ``None`` falls back to ``default``.

    Args:
        args: any mapping with a ``get(key, default)`` method.
        key: the key to read.
        default: the value used when the key is missing or ``None``; may be a
            comma-separated string or a list/tuple.
    """
    value = args.get(key, default)
    if value is None:
        value = default if default is not None else ""

    if isinstance(value, (list, tuple)):
        items = (str(item).strip() for item in value)
        return [item for item in items if item]

    return split_list(value)
2023-02-05 13:53:26 +00:00
def get_and_clamp_float(
    args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float:
    """Read ``key`` from ``args`` as a float clamped to ``[min_value, max_value]``.

    A NaN value (e.g. the string ``"nan"``) is replaced by ``default_value``
    before clamping.

    Raises:
        ValueError: if the raw value cannot be converted by ``float()``.
    """
    import math

    value = float(args.get(key, default_value))
    if math.isnan(value):
        # NaN compares false against everything, so min/max would let it
        # through unclamped; treat it like a missing value instead
        value = float(default_value)

    return min(max(value, min_value), max_value)
2023-02-05 13:53:26 +00:00
def get_and_clamp_int(
    args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int:
    """Read ``key`` from ``args`` as an int clamped to ``[min_value, max_value]``.

    The lower bound is applied first, so ``max_value`` wins if the bounds are
    inverted.

    Raises:
        ValueError: if the raw value cannot be converted by ``int()``.
    """
    value = int(args.get(key, default_value))
    if value < min_value:
        value = min_value
    if value > max_value:
        value = max_value
    return value
# Generic element type for the typed getter helpers in this module.
TElem = TypeVar("TElem")


def get_from_list(
    args: Any, key: str, values: Sequence[TElem], default_value: Optional[TElem] = None
) -> Optional[TElem]:
    """Read ``key`` from ``args`` and require it to be one of ``values``.

    If the selection is not in ``values``, a warning is logged and the result
    falls back to ``default_value`` when that is itself a valid option,
    otherwise to the first entry of ``values``, or ``None`` when ``values``
    is empty.

    Args:
        args: any mapping with a ``get(key, default)`` method.
        key: the key to read.
        values: the allowed options.
        default_value: the value used when the key is missing, and the
            preferred fallback for an invalid selection.
    """
    selected = args.get(key, default_value)
    if selected in values:
        return selected

    logger.warning("invalid selection %s, options: %s", selected, values)

    # previously an invalid selection ignored a valid default_value and
    # always fell back to values[0]
    if default_value is not None and default_value in values:
        return default_value

    if len(values) > 0:
        return values[0]

    return None
2023-01-17 02:10:52 +00:00
2023-07-04 17:09:46 +00:00
def get_from_map(
    args: Any, key: str, values: Dict[str, TElem], default_key: str
) -> TElem:
    """Look up the selection ``args[key]`` in ``values``.

    A missing key or an unknown selection falls back to ``values[default_key]``.

    Raises:
        KeyError: if the fallback ``default_key`` is not in ``values`` either.
    """
    chosen = args.get(key, default_key)
    lookup_key = chosen if chosen in values else default_key
    return values[lookup_key]
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
    """Return ``args[key]``, or ``default`` when it is missing, ``None`` or empty.

    "Empty" means ``len(value) == 0``, so the value must support ``len()``.
    """
    candidate = args.get(key, default)
    is_missing = candidate is None or len(candidate) == 0
    return default if is_missing else candidate
2023-02-20 04:10:35 +00:00
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
    """Parse a size parameter.

    * ``None`` becomes ``SizeChart.auto``;
    * an ``int`` (exactly, not ``bool``) is returned unchanged;
    * a ``str`` that equals the name of a ``SizeChart`` member returns that
      member, and any other string is parsed with ``int()``.

    Raises:
        ValueError: for any other type, or a string ``int()`` cannot parse.
    """
    if val is None:
        return SizeChart.auto

    value_type = type(val)
    if value_type is int:
        return val

    if value_type is not str:
        raise ValueError("invalid size")

    named = next((size for size in SizeChart if size.name == val), None)
    if named is not None:
        return named

    return int(val)
2023-01-29 05:06:25 +00:00
def run_gc(devices: Optional[List[DeviceParams]] = None):
    """Run Python garbage collection, then clean up each CUDA device in ``devices``.

    For every device whose ``device`` string starts with ``cuda`` (and only when
    ``torch.cuda.is_available()``), this calls ``torch.cuda.empty_cache()`` and
    ``torch.cuda.ipc_collect()`` with that device selected, then logs its VRAM
    usage from ``torch.cuda.mem_get_info()`` at debug level.
    """
    logger.debug(
        "running garbage collection with %s active threads", threading.active_count()
    )
    gc.collect()

    if not torch.cuda.is_available() or devices is None:
        return

    cuda_devices = [d for d in devices if d.device.startswith("cuda")]
    for device in cuda_devices:
        logger.debug("running Torch garbage collection for device: %s", device)
        # mem_get_info reports on the current device, so stay inside the context
        with torch.cuda.device(device.torch_str()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

            free_bytes, total_bytes = torch.cuda.mem_get_info()
            used_bytes = total_bytes - free_bytes
            used_pct = (1 - (free_bytes / total_bytes)) * 100
            logger.debug(
                "CUDA VRAM usage: %s of %s (%.2f%%)",
                used_bytes,
                total_bytes,
                used_pct,
            )
def sanitize_name(name):
    """Return ``name`` keeping only alphanumeric characters and those in SAFE_CHARS."""
    kept = [char for char in name if char.isalnum() or char in SAFE_CHARS]
    return "".join(kept)
def merge(a, b, path=None):
    """Recursively merge ``b`` into ``a`` in place and return ``a``.

    Keys missing from ``a`` are copied over (by reference). When both sides
    hold a dict, those dicts are merged recursively. Equal leaf values are
    left alone.

    Args:
        a: the dict to merge into; it is mutated.
        b: the dict to merge from.
        path: key trail used in error messages during recursion.

    Raises:
        ValueError: when the same key holds different non-dict values; the
            message names the dotted key path, e.g. ``conflict at x.y``.
    """
    trail = [] if path is None else path

    for key in b:
        incoming = b[key]
        if key not in a:
            a[key] = incoming
            continue

        existing = a[key]
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merge(existing, incoming, trail + [str(key)])
        elif not (existing == incoming):
            raise ValueError("conflict at %s" % ".".join(trail + [str(key)]))

    return a
# Lazily-created win10toast ToastNotifier, reused across show_system_toast calls.
toaster = None


def show_system_toast(msg: str) -> None:
    """Show ``msg`` as a desktop notification, if the platform supports it.

    On Linux this uses libnotify through PyGObject (``gi.repository.Notify``).
    On Windows it uses the ``win10toast`` package. When the required module is
    missing, or on any other platform, an info message is logged instead.
    """
    global toaster

    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before calling find_spec
    import importlib.util

    sys_name = system()
    if sys_name == "Linux":
        if (
            importlib.util.find_spec("gi") is not None
            and importlib.util.find_spec("gi.repository") is not None
        ):
            try:
                from gi.repository import Notify
            except ImportError as err:
                # e.g. PyGObject is installed but the libnotify typelib is not
                logger.info(
                    "libnotify is not available, cannot show toast notifications on Linux: %s",
                    err,
                )
                return

            # toaster is only populated on Windows, so ask libnotify itself
            # whether it is initialized instead of calling init on every toast
            if not Notify.is_initted():
                Notify.init("onnx-web")
            Notify.Notification.new(msg).show()
        else:
            logger.info(
                "please install the PyGObject module to enable toast notifications on Linux"
            )
    elif sys_name == "Windows":
        if importlib.util.find_spec("win10toast") is not None:
            from win10toast import ToastNotifier

            if toaster is None:
                toaster = ToastNotifier()
            toaster.show_toast(msg, duration=15)
        else:
            logger.info(
                "please install the win10toast module to enable toast notifications on Windows"
            )
    else:
        logger.info("system notifications not yet available for %s", sys_name)
def load_json(file: str) -> Dict:
    """Load and return the JSON document stored in ``file``.

    The file is read as UTF-8, the encoding JSON requires (RFC 8259). It is not
    read with the platform's locale default, which ``open()`` would otherwise
    use (for example cp1252 on many Windows systems).
    """
    with open(file, "r", encoding="utf-8") as f:
        return json.load(f)
def load_yaml(file: str) -> Dict:
    """Load and return the YAML document stored in ``file``.

    Uses ``yaml.safe_load``, which only constructs standard YAML tags. The
    file is read as UTF-8 rather than with the platform's locale encoding.
    """
    with open(file, "r", encoding="utf-8") as f:
        return safe_load(f)
def load_config(file: str) -> Dict:
    """Load a JSON or YAML config file, choosing the parser by file extension.

    Extensions are matched case-insensitively: ``.json`` is loaded with
    ``load_json``, and ``.yml``/``.yaml`` with ``load_yaml``.

    Raises:
        ValueError: if the extension is not one of the above.
    """
    _, ext = path.splitext(file)
    ext = ext.lower()

    if ext in (".yml", ".yaml"):
        return load_yaml(file)

    if ext == ".json":
        return load_json(file)

    raise ValueError("unknown config file extension: %s" % (file,))
def load_config_str(raw: str) -> Dict:
    """Parse ``raw`` as JSON, falling back to YAML when it is not valid JSON."""
    try:
        parsed = json.loads(raw)
    except JSONDecodeError:
        parsed = safe_load(raw)
    return parsed
# Chunk size for hash_file reads (4 MiB), so files are not loaded into
# memory all at once.
HASH_BUFFER_SIZE = 4 * 1024 * 1024


def hash_file(name: str):
    """Return the hex-encoded SHA-256 digest of the file at ``name``."""
    digest = sha256()
    with open(name, "rb") as handle:
        for chunk in iter(lambda: handle.read(HASH_BUFFER_SIZE), b""):
            digest.update(chunk)
    return digest.hexdigest()
def hash_value(sha, param: Optional[Param]):
    """Feed a binary encoding of ``param`` into the hash object ``sha``.

    ``None`` is skipped. Booleans, floats, ints and strings are encoded in a
    big-endian binary form. Any other type is logged and skipped.

    NOTE(review): floats are packed as 32-bit ``!f``, so values that differ
    only beyond single precision produce the same hash input.

    Args:
        sha: a hashlib-style object with an ``update(bytes)`` method.
        param: the value to hash.
    """
    if param is None:
        return
    elif isinstance(param, bool):
        # checked before int, since bool is a subclass of int
        sha.update(pack("!B", param))
    elif isinstance(param, float):
        try:
            sha.update(pack("!f", param))
        except OverflowError:
            # too large for a 32-bit float: use double precision instead of raising
            sha.update(pack("!d", param))
    elif isinstance(param, int):
        if 0 <= param <= 0xFFFFFFFF:
            # unchanged encoding, so hashes of in-range values stay stable
            sha.update(pack("!I", param))
        else:
            # "!I" only accepts unsigned 32-bit values and raises struct.error
            # otherwise; encode out-of-range ints as minimal-width signed
            # big-endian bytes (bit_length + 1 sign bit, rounded up to bytes)
            width = (param.bit_length() + 8) // 8
            sha.update(param.to_bytes(width, "big", signed=True))
    elif isinstance(param, str):
        sha.update(param.encode("utf-8"))
    else:
        logger.warning("cannot hash param: %s, %s", param, type(param))