1
0
Fork 0
onnx-web/api/onnx_web/utils.py

245 lines
6.1 KiB
Python

import importlib
import json
from functools import reduce
from hashlib import sha256
from json import JSONDecodeError
from logging import getLogger
from os import environ, path
from platform import system
from struct import pack
from typing import Any, Dict, List, Optional, Sequence, TypeVar
from yaml import safe_load
logger = getLogger(__name__)
SAFE_CHARS = "._-"
def split_list(val: str) -> List[str]:
parts = [part.strip() for part in val.split(",")]
return [part for part in parts if len(part) > 0]
def base_join(base: str, tail: str) -> str:
tail_path = path.relpath(path.normpath(path.join("/", tail)), "/")
return path.join(base, tail_path)
def is_debug() -> bool:
return get_boolean(environ, "DEBUG", False)
def recursive_get(d, keys, default_value=None):
empty_dict = {}
val = reduce(lambda c, k: c.get(k, empty_dict), keys, d)
if val == empty_dict:
return default_value
return val
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
val = recursive_get(args, key.split("."), default_value=str(default_value))
if isinstance(val, bool):
return val
return val.lower() in ("1", "t", "true", "y", "yes")
def get_list(args: Any, key: str, default="") -> List[str]:
val = recursive_get(args, key.split("."), default_value=default)
return split_list(val)
def get_and_clamp_float(
args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float:
val = recursive_get(args, key.split("."), default_value=default_value)
return min(max(float(val), min_value), max_value)
def get_and_clamp_int(
args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int:
val = recursive_get(args, key.split("."), default_value=default_value)
return min(max(int(val), min_value), max_value)
TElem = TypeVar("TElem")
def get_from_list(
args: Any, key: str, values: Sequence[TElem], default_value: Optional[TElem] = None
) -> Optional[TElem]:
selected = args.get(key, default_value)
if selected in values:
return selected
logger.warning("invalid selection %s, options: %s", selected, values)
if len(values) > 0:
return values[0]
return None
def get_from_map(
args: Any, key: str, values: Dict[str, TElem], default_key: str
) -> TElem:
selected = args.get(key, default_key)
if selected in values:
return values[selected]
else:
return values[default_key]
def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
val = args.get(key, default)
if val is None or len(val) == 0:
val = default
return val
def run_gc(devices: Optional[List[Any]] = None):
"""
Deprecated, use `onnx_web.device.run_gc` instead.
"""
from .device import run_gc as run_gc_impl
logger.debug("calling deprecated run_gc, please use onnx_web.device.run_gc instead")
run_gc_impl(devices)
def sanitize_name(name):
return "".join(x for x in name if (x.isalnum() or x in SAFE_CHARS))
def merge(a, b, path=None):
"merges b into a"
if path is None:
path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
merge(a[key], b[key], path + [str(key)])
elif a[key] == b[key]:
pass # same leaf value
else:
raise ValueError("conflict at %s" % ".".join(path + [str(key)]))
else:
a[key] = b[key]
return a
toaster = None
def show_system_toast(msg: str) -> None:
global toaster
sys_name = system()
if sys_name == "Linux":
if (
importlib.util.find_spec("gi") is not None
and importlib.util.find_spec("gi.repository") is not None
):
from gi.repository import Notify
if toaster is None:
Notify.init("onnx-web")
Notify.Notification.new(msg).show()
else:
logger.info(
"please install the PyGObject module to enable toast notifications on Linux"
)
elif sys_name == "Windows":
if importlib.util.find_spec("win10toast") is not None:
from win10toast import ToastNotifier
if toaster is None:
toaster = ToastNotifier()
toaster.show_toast(msg, duration=15)
else:
logger.info(
"please install the win10toast module to enable toast notifications on Windows"
)
else:
logger.info("system notifications not yet available for %s", sys_name)
def load_json(file: str) -> Dict:
with open(file, "r") as f:
data = json.loads(f.read())
return data
def load_yaml(file: str) -> Dict:
with open(file, "r") as f:
data = safe_load(f.read())
return data
def load_config(file: str) -> Dict:
name, ext = path.splitext(file)
if ext in [".yml", ".yaml"]:
return load_yaml(file)
elif ext in [".json"]:
return load_json(file)
else:
raise ValueError("unknown config file extension")
def load_config_str(raw: str) -> Dict:
try:
return json.loads(raw)
except JSONDecodeError:
return safe_load(raw)
HASH_BUFFER_SIZE = 2**22 # 4MB
def hash_file(name: str):
sha = sha256()
with open(name, "rb") as f:
while True:
data = f.read(HASH_BUFFER_SIZE)
if not data:
break
sha.update(data)
return sha.hexdigest()
def hash_value(sha, param: Optional[Any]):
if param is None:
return None
elif isinstance(param, bool):
sha.update(bytearray(pack("!B", param)))
elif isinstance(param, float):
sha.update(bytearray(pack("!f", param)))
elif isinstance(param, int):
sha.update(bytearray(pack("!I", param)))
elif isinstance(param, str):
sha.update(param.encode("utf-8"))
else:
logger.warning("cannot hash param: %s, %s", param, type(param))
def coalesce(*args, throw=False):
for arg in args:
if arg is not None:
return arg
if throw:
raise ValueError("no value found")
return None