2023-02-05 13:53:26 +00:00
|
|
|
import gc
|
2023-02-18 04:49:13 +00:00
|
|
|
import threading
|
2023-01-28 23:09:19 +00:00
|
|
|
from logging import getLogger
|
2023-01-16 22:39:30 +00:00
|
|
|
from os import environ, path
|
2023-02-02 14:31:35 +00:00
|
|
|
from typing import Any, Dict, List, Optional, Union
|
2023-01-16 00:54:20 +00:00
|
|
|
|
2023-02-02 03:20:48 +00:00
|
|
|
import torch
|
|
|
|
|
2023-02-17 00:11:35 +00:00
|
|
|
from .params import DeviceParams, SizeChart
|
2023-01-16 01:14:58 +00:00
|
|
|
|
2023-01-28 23:09:19 +00:00
|
|
|
# Module-level logger named after this module (standard logging pattern).
logger = getLogger(__name__)
|
|
|
|
|
2023-01-16 01:14:58 +00:00
|
|
|
|
2023-02-02 14:31:35 +00:00
|
|
|
def base_join(base: str, tail: str) -> str:
    """Join `tail` onto `base`, confining the result under `base`.

    The tail is anchored at the filesystem root and normalized first, which
    collapses any `..`/`.` segments so a hostile tail cannot traverse out of
    the base directory.
    """
    # normpath(join("/", tail)) strips leading parent references; relpath
    # against "/" turns the sanitized absolute path back into a relative one.
    safe_tail = path.relpath(path.normpath(path.join("/", tail)), "/")
    return path.join(base, safe_tail)
|
|
|
|
|
|
|
|
|
2023-01-20 01:46:36 +00:00
|
|
|
def is_debug() -> bool:
    """Report whether debug mode is enabled via the DEBUG env var."""
    # Delegates to get_boolean so the same truthy strings are accepted
    # ("1", "t", "true", "y", "yes", case-insensitive).
    return get_boolean(environ, "DEBUG", False)
|
|
|
|
|
|
|
|
|
2023-02-11 22:50:57 +00:00
|
|
|
def get_boolean(args: Any, key: str, default_value: bool) -> bool:
    """Read a boolean flag from a dict-like source of string values.

    A missing key falls back to `str(default_value)`, so the default is run
    through the same truthy-string check as a real value.
    """
    truthy = ("1", "t", "true", "y", "yes")
    raw = args.get(key, str(default_value))
    return raw.lower() in truthy
|
2023-01-20 01:46:36 +00:00
|
|
|
|
|
|
|
|
2023-02-05 13:53:26 +00:00
|
|
|
def get_and_clamp_float(
    args: Any, key: str, default_value: float, max_value: float, min_value=0.0
) -> float:
    """Read a float from `args` and clamp it into [min_value, max_value]."""
    value = float(args.get(key, default_value))
    # Raise to the floor first, then cap at the ceiling (same order as the
    # classic min(max(...)) clamp).
    floored = max(value, min_value)
    return min(floored, max_value)
|
|
|
|
|
|
|
|
|
2023-02-05 13:53:26 +00:00
|
|
|
def get_and_clamp_int(
    args: Any, key: str, default_value: int, max_value: int, min_value=1
) -> int:
    """Read an int from `args` and clamp it into [min_value, max_value]."""
    value = int(args.get(key, default_value))
    # Raise to the floor first, then cap at the ceiling (same order as the
    # classic min(max(...)) clamp).
    floored = max(value, min_value)
    return min(floored, max_value)
|
|
|
|
|
|
|
|
|
2023-01-27 23:08:36 +00:00
|
|
|
def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]:
    """Read a value from `args` and validate it against a list of options.

    Returns the selected value when it is one of `values`. Otherwise logs a
    warning and falls back to the first option, or None when `values` is
    empty.
    """
    selected = args.get(key, None)
    if selected in values:
        return selected

    # Logger.warn is a deprecated alias; warning() is the supported method.
    logger.warning("invalid selection: %s", selected)
    if len(values) > 0:
        return values[0]

    return None
|
|
|
|
|
2023-01-17 02:10:52 +00:00
|
|
|
|
2023-01-22 21:44:31 +00:00
|
|
|
def get_from_map(args: Any, key: str, values: Dict[str, Any], default: Any) -> Any:
    """Look up a keyed selection in `values`, falling back to `default`.

    The value stored under `key` in `args` picks the map entry; an unknown
    selection (or a missing key) resolves to `values[default]` instead.
    """
    choice = args.get(key, default)
    if choice not in values:
        # Fall back to the default entry; this raises KeyError when
        # `default` itself is not a key of the map, same as before.
        choice = default
    return values[choice]
|
|
|
|
|
|
|
|
|
2023-01-22 21:44:31 +00:00
|
|
|
def get_not_empty(args: Any, key: str, default: Any) -> Any:
    """Read a value from `args`, substituting `default` for empty values.

    Both None and zero-length values (empty strings, lists, etc.) count as
    empty and are replaced with the default.
    """
    value = args.get(key, default)
    if value is None or len(value) == 0:
        return default
    return value
|
|
|
|
|
|
|
|
|
2023-02-20 04:10:35 +00:00
|
|
|
def get_size(val: Union[int, str, None]) -> Union[int, SizeChart]:
|
2023-01-29 05:06:25 +00:00
|
|
|
if val is None:
|
|
|
|
return SizeChart.auto
|
|
|
|
|
|
|
|
if type(val) is int:
|
|
|
|
return val
|
|
|
|
|
2023-01-29 05:46:36 +00:00
|
|
|
if type(val) is str:
|
|
|
|
for size in SizeChart:
|
|
|
|
if val == size.name:
|
|
|
|
return size
|
|
|
|
|
|
|
|
return int(val)
|
|
|
|
|
2023-02-20 04:10:35 +00:00
|
|
|
raise ValueError("invalid size")
|
2023-01-29 05:06:25 +00:00
|
|
|
|
|
|
|
|
2023-02-19 13:41:16 +00:00
|
|
|
def run_gc(devices: Optional[List[DeviceParams]] = None):
    """Run Python garbage collection, then Torch CUDA cleanup per device.

    When CUDA is available and `devices` is provided, empties the CUDA
    cache and collects IPC handles for each CUDA device in the list, then
    logs the remaining VRAM usage.

    Note: the annotation is Optional[List[DeviceParams]] — the previous
    bare `List[...] = None` was an implicit-Optional (PEP 484 violation).
    """
    logger.debug(
        "running garbage collection with %s active threads", threading.active_count()
    )
    gc.collect()

    if torch.cuda.is_available() and devices is not None:
        # Only CUDA devices have caches to empty; skip CPU/other providers.
        for device in [d for d in devices if d.device.startswith("cuda")]:
            logger.debug("running Torch garbage collection for device: %s", device)
            with torch.cuda.device(device.torch_str()):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()
                mem_free, mem_total = torch.cuda.mem_get_info()
                logger.debug(
                    "remaining CUDA VRAM usage: %s of %s",
                    (mem_total - mem_free),
                    mem_total,
                )
|