fix circular imports
This commit is contained in:
parent
e09b1ce6b8
commit
4b5ff69f98
|
@ -591,3 +591,20 @@ class HighresParams:
|
||||||
method=coalesce(method, self.method),
|
method=coalesce(method, self.method),
|
||||||
iterations=coalesce(iterations, self.iterations),
|
iterations=coalesce(iterations, self.iterations),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
|
||||||
|
if val is None:
|
||||||
|
return SizeChart.auto
|
||||||
|
|
||||||
|
if type(val) is int:
|
||||||
|
return val
|
||||||
|
|
||||||
|
if type(val) is str:
|
||||||
|
for size in SizeChart:
|
||||||
|
if val == size.name:
|
||||||
|
return size
|
||||||
|
|
||||||
|
return int(val)
|
||||||
|
|
||||||
|
raise ValueError("invalid size")
|
||||||
|
|
|
@ -1,20 +1,15 @@
|
||||||
import gc
|
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
import threading
|
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
from platform import system
|
from platform import system
|
||||||
from struct import pack
|
from struct import pack
|
||||||
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
|
from typing import Any, Dict, List, Optional, Sequence, TypeVar
|
||||||
|
|
||||||
import torch
|
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
from .params import DeviceParams, Param, SizeChart
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
SAFE_CHARS = "._-"
|
SAFE_CHARS = "._-"
|
||||||
|
@ -95,43 +90,14 @@ def get_not_empty(args: Any, key: str, default: TElem) -> TElem:
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
|
||||||
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
|
def run_gc(devices: Optional[List[Any]] = None):
|
||||||
if val is None:
|
"""
|
||||||
return SizeChart.auto
|
Deprecated, use `onnx_web.device.run_gc` instead.
|
||||||
|
"""
|
||||||
|
from .device import run_gc as run_gc_impl
|
||||||
|
|
||||||
if type(val) is int:
|
logger.debug("calling deprecated run_gc, please use onnx_web.device.run_gc instead")
|
||||||
return val
|
run_gc_impl(devices)
|
||||||
|
|
||||||
if type(val) is str:
|
|
||||||
for size in SizeChart:
|
|
||||||
if val == size.name:
|
|
||||||
return size
|
|
||||||
|
|
||||||
return int(val)
|
|
||||||
|
|
||||||
raise ValueError("invalid size")
|
|
||||||
|
|
||||||
|
|
||||||
def run_gc(devices: Optional[List[DeviceParams]] = None):
|
|
||||||
logger.debug(
|
|
||||||
"running garbage collection with %s active threads", threading.active_count()
|
|
||||||
)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
if torch.cuda.is_available() and devices is not None:
|
|
||||||
for device in [d for d in devices if d.device.startswith("cuda")]:
|
|
||||||
logger.debug("running Torch garbage collection for device: %s", device)
|
|
||||||
with torch.cuda.device(device.torch_str()):
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.ipc_collect()
|
|
||||||
mem_free, mem_total = torch.cuda.mem_get_info()
|
|
||||||
mem_pct = (1 - (mem_free / mem_total)) * 100
|
|
||||||
logger.debug(
|
|
||||||
"CUDA VRAM usage: %s of %s (%.2f%%)",
|
|
||||||
(mem_total - mem_free),
|
|
||||||
mem_total,
|
|
||||||
mem_pct,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sanitize_name(name):
|
def sanitize_name(name):
|
||||||
|
@ -238,9 +204,9 @@ def hash_file(name: str):
|
||||||
return sha.hexdigest()
|
return sha.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def hash_value(sha, param: Optional[Param]):
|
def hash_value(sha, param: Optional[Any]):
|
||||||
if param is None:
|
if param is None:
|
||||||
return
|
return None
|
||||||
elif isinstance(param, bool):
|
elif isinstance(param, bool):
|
||||||
sha.update(bytearray(pack("!B", param)))
|
sha.update(bytearray(pack("!B", param)))
|
||||||
elif isinstance(param, float):
|
elif isinstance(param, float):
|
||||||
|
|
Loading…
Reference in New Issue