fix(api): show VRAM percent in logs
This commit is contained in:
parent
7a3a81a4ef
commit
39b9741b24
|
@@ -97,10 +97,12 @@ def run_gc(devices: Optional[List[DeviceParams]] = None):
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
mem_free, mem_total = torch.cuda.mem_get_info()
|
mem_free, mem_total = torch.cuda.mem_get_info()
|
||||||
|
mem_pct = (1 - (mem_free / mem_total)) * 100
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"remaining CUDA VRAM usage: %s of %s",
|
"CUDA VRAM usage: %s of %s (%.2f%%)",
|
||||||
(mem_total - mem_free),
|
(mem_total - mem_free),
|
||||||
mem_total,
|
mem_total,
|
||||||
|
mem_pct,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@@ -291,7 +291,9 @@ class DevicePoolExecutor:
|
||||||
logger.debug("shutting down worker for device %s", device)
|
logger.debug("shutting down worker for device %s", device)
|
||||||
worker.join(self.join_timeout)
|
worker.join(self.join_timeout)
|
||||||
if worker.is_alive():
|
if worker.is_alive():
|
||||||
logger.error("leaking worker for device %s could not be shut down", device)
|
logger.error(
|
||||||
|
"leaking worker for device %s could not be shut down", device
|
||||||
|
)
|
||||||
|
|
||||||
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
|
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue