add missing file
This commit is contained in:
parent
028547d90b
commit
dd02b059ba
|
@ -0,0 +1,32 @@
|
||||||
|
import gc
|
||||||
|
import threading
|
||||||
|
from logging import getLogger
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .params import DeviceParams
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def run_gc(devices: Optional[List[DeviceParams]] = None):
|
||||||
|
logger.debug(
|
||||||
|
"running garbage collection with %s active threads", threading.active_count()
|
||||||
|
)
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if torch.cuda.is_available() and devices is not None:
|
||||||
|
for device in [d for d in devices if d.device.startswith("cuda")]:
|
||||||
|
logger.debug("running Torch garbage collection for device: %s", device)
|
||||||
|
with torch.cuda.device(device.torch_str()):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.ipc_collect()
|
||||||
|
mem_free, mem_total = torch.cuda.mem_get_info()
|
||||||
|
mem_pct = (1 - (mem_free / mem_total)) * 100
|
||||||
|
logger.debug(
|
||||||
|
"CUDA VRAM usage: %s of %s (%.2f%%)",
|
||||||
|
(mem_total - mem_free),
|
||||||
|
mem_total,
|
||||||
|
mem_pct,
|
||||||
|
)
|
Loading…
Reference in New Issue