fix(scripts): check Torch CUDA devices (#242)
This commit is contained in:
parent
a9456f4a16
commit
391a707c84
|
@ -59,10 +59,31 @@ def check_providers() -> List[str]:
|
|||
return results
|
||||
|
||||
|
||||
def check_torch_cuda() -> List[str]:
|
||||
results = []
|
||||
try:
|
||||
import torch.cuda
|
||||
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
with torch.cuda.device(i):
|
||||
mem_free, mem_total = torch.cuda.mem_get_info()
|
||||
results.append(f"Torch has CUDA device {i} with {mem_free} of {mem_total} bytes of free VRAM")
|
||||
else:
|
||||
results.append("loaded Torch but CUDA was not available")
|
||||
except ImportError as e:
|
||||
results.append(f"unable to import Torch CUDA: {e}")
|
||||
except Exception as e:
|
||||
results.append(f"error listing Torch CUDA: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
ALL_CHECKS = [
|
||||
check_modules,
|
||||
check_runtimes,
|
||||
check_providers,
|
||||
check_torch_cuda,
|
||||
]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue