From 4b5ff69f985a2c16e3224d10095e9297d81b34c8 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Tue, 30 Jan 2024 09:55:35 -0600 Subject: [PATCH] fix circular imports --- api/onnx_web/params.py | 17 +++++++++++++ api/onnx_web/utils.py | 54 ++++++++---------------------------------- 2 files changed, 27 insertions(+), 44 deletions(-) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index f0bf7b7f..0db05d26 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -591,3 +591,20 @@ class HighresParams: method=coalesce(method, self.method), iterations=coalesce(iterations, self.iterations), ) + + +def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: + if val is None: + return SizeChart.auto + + if type(val) is int: + return val + + if type(val) is str: + for size in SizeChart: + if val == size.name: + return size + + return int(val) + + raise ValueError("invalid size") diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 3db2ad2e..a888133d 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -1,20 +1,15 @@ -import gc import importlib import json -import threading from hashlib import sha256 from json import JSONDecodeError from logging import getLogger from os import environ, path from platform import system from struct import pack -from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union +from typing import Any, Dict, List, Optional, Sequence, TypeVar -import torch from yaml import safe_load -from .params import DeviceParams, Param, SizeChart - logger = getLogger(__name__) SAFE_CHARS = "._-" @@ -95,43 +90,14 @@ def get_not_empty(args: Any, key: str, default: TElem) -> TElem: return val -def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]: - if val is None: - return SizeChart.auto +def run_gc(devices: Optional[List[Any]] = None): + """ + Deprecated, use `onnx_web.device.run_gc` instead. + """ + from .device import run_gc as run_gc_impl - if type(val) is int: - return val - - if type(val) is str: - for size in SizeChart: - if val == size.name: - return size - - return int(val) - - raise ValueError("invalid size") - - -def run_gc(devices: Optional[List[DeviceParams]] = None): - logger.debug( - "running garbage collection with %s active threads", threading.active_count() - ) - gc.collect() - - if torch.cuda.is_available() and devices is not None: - for device in [d for d in devices if d.device.startswith("cuda")]: - logger.debug("running Torch garbage collection for device: %s", device) - with torch.cuda.device(device.torch_str()): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - mem_free, mem_total = torch.cuda.mem_get_info() - mem_pct = (1 - (mem_free / mem_total)) * 100 - logger.debug( - "CUDA VRAM usage: %s of %s (%.2f%%)", - (mem_total - mem_free), - mem_total, - mem_pct, - ) + logger.debug("calling deprecated run_gc, please use onnx_web.device.run_gc instead") + run_gc_impl(devices) def sanitize_name(name): @@ -238,9 +204,9 @@ def hash_file(name: str): return sha.hexdigest() -def hash_value(sha, param: Optional[Param]): +def hash_value(sha, param: Optional[Any]): if param is None: - return + return None elif isinstance(param, bool): sha.update(bytearray(pack("!B", param))) elif isinstance(param, float):