1
0
Fork 0

fix circular imports

This commit is contained in:
Sean Sube 2024-01-30 09:55:35 -06:00
parent e09b1ce6b8
commit 4b5ff69f98
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 27 additions and 44 deletions

View File

@ -591,3 +591,20 @@ class HighresParams:
method=coalesce(method, self.method),
iterations=coalesce(iterations, self.iterations),
)
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
if val is None:
return SizeChart.auto
if type(val) is int:
return val
if type(val) is str:
for size in SizeChart:
if val == size.name:
return size
return int(val)
raise ValueError("invalid size")

View File

@ -1,20 +1,15 @@
import gc
import importlib
import json
import threading
from hashlib import sha256
from json import JSONDecodeError
from logging import getLogger
from os import environ, path
from platform import system
from struct import pack
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
from typing import Any, Dict, List, Optional, Sequence, TypeVar
import torch
from yaml import safe_load
from .params import DeviceParams, Param, SizeChart
logger = getLogger(__name__)
SAFE_CHARS = "._-"
@ -95,43 +90,14 @@ def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
return val
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
if val is None:
return SizeChart.auto
def run_gc(devices: Optional[List[Any]] = None):
"""
Deprecated, use `onnx_web.device.run_gc` instead.
"""
from .device import run_gc as run_gc_impl
if type(val) is int:
return val
if type(val) is str:
for size in SizeChart:
if val == size.name:
return size
return int(val)
raise ValueError("invalid size")
def run_gc(devices: Optional[List[DeviceParams]] = None):
logger.debug(
"running garbage collection with %s active threads", threading.active_count()
)
gc.collect()
if torch.cuda.is_available() and devices is not None:
for device in [d for d in devices if d.device.startswith("cuda")]:
logger.debug("running Torch garbage collection for device: %s", device)
with torch.cuda.device(device.torch_str()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
mem_free, mem_total = torch.cuda.mem_get_info()
mem_pct = (1 - (mem_free / mem_total)) * 100
logger.debug(
"CUDA VRAM usage: %s of %s (%.2f%%)",
(mem_total - mem_free),
mem_total,
mem_pct,
)
logger.debug("calling deprecated run_gc, please use onnx_web.device.run_gc instead")
run_gc_impl(devices)
def sanitize_name(name):
@ -238,9 +204,9 @@ def hash_file(name: str):
return sha.hexdigest()
def hash_value(sha, param: Optional[Param]):
def hash_value(sha, param: Optional[Any]):
if param is None:
return
return None
elif isinstance(param, bool):
sha.update(bytearray(pack("!B", param)))
elif isinstance(param, float):