
fix(scripts): check Torch CUDA devices (#242)

Sean Sube 2023-03-18 13:54:12 -05:00
parent a9456f4a16
commit 391a707c84
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 21 additions and 0 deletions

@@ -59,10 +59,31 @@ def check_providers() -> List[str]:
    return results


def check_torch_cuda() -> List[str]:
    results = []

    try:
        import torch.cuda

        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                with torch.cuda.device(i):
                    mem_free, mem_total = torch.cuda.mem_get_info()
                    results.append(f"Torch has CUDA device {i} with {mem_free} of {mem_total} bytes of free VRAM")
        else:
            results.append("loaded Torch but CUDA was not available")
    except ImportError as e:
        results.append(f"unable to import Torch CUDA: {e}")
    except Exception as e:
        results.append(f"error listing Torch CUDA: {e}")

    return results
ALL_CHECKS = [
    check_modules,
    check_runtimes,
    check_providers,
    check_torch_cuda,
]
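
For context, a minimal sketch of how a list of check functions like `ALL_CHECKS` might be executed and reported; the `run_checks` runner and the stand-in check below are assumptions for illustration and are not part of this commit.

```python
from typing import Callable, List


def run_checks(checks: List[Callable[[], List[str]]]) -> None:
    """Run each environment check and print whatever lines it reports."""
    for check in checks:
        print(f"> {check.__name__}")
        for line in check():
            print(f"  {line}")


if __name__ == "__main__":
    # stand-in check so the sketch runs on its own; in the real script this
    # would be the ALL_CHECKS list from the diff above, including check_torch_cuda
    def check_nothing() -> List[str]:
        return ["no checks configured"]

    run_checks([check_nothing])
```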