fix(scripts): check Torch CUDA devices (#242)
This commit is contained in:
parent
a9456f4a16
commit
391a707c84
|
@ -59,10 +59,31 @@ def check_providers() -> List[str]:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def check_torch_cuda() -> List[str]:
    """Probe Torch's CUDA support and report free VRAM per visible device.

    Returns a list of human-readable status strings: one line per CUDA
    device when CUDA is usable, otherwise a single line explaining why
    (Torch import failure, CUDA unavailable, or an unexpected error).
    Never raises — this is a best-effort diagnostic check.
    """
    messages: List[str] = []
    try:
        # Imported lazily so environments without Torch still get a report
        # instead of a crash at module load.
        import torch.cuda

        if not torch.cuda.is_available():
            messages.append("loaded Torch but CUDA was not available")
        else:
            for idx in range(torch.cuda.device_count()):
                # mem_get_info() reports for the *current* device, so make
                # each device current in turn.
                with torch.cuda.device(idx):
                    free_bytes, total_bytes = torch.cuda.mem_get_info()
                    messages.append(
                        f"Torch has CUDA device {idx} with {free_bytes} of {total_bytes} bytes of free VRAM"
                    )
    except ImportError as e:
        messages.append(f"unable to import Torch CUDA: {e}")
    except Exception as e:
        # Deliberately broad: a diagnostics check must report, not propagate.
        messages.append(f"error listing Torch CUDA: {e}")

    return messages
|
|
||||||
|
|
||||||
# Registry of diagnostic checks run by this script; each entry is a
# zero-argument callable returning a List[str] of status messages.
ALL_CHECKS = [
    check_modules,
    check_runtimes,
    check_providers,
    check_torch_cuda,
]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue