diff --git a/api/scripts/check-env.py b/api/scripts/check-env.py index cf7b8b16..aa8e228f 100644 --- a/api/scripts/check-env.py +++ b/api/scripts/check-env.py @@ -59,10 +59,31 @@ def check_providers() -> List[str]: return results +def check_torch_cuda() -> List[str]: + results = [] + try: + import torch.cuda + + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + with torch.cuda.device(i): + mem_free, mem_total = torch.cuda.mem_get_info() + results.append(f"Torch has CUDA device {i} with {mem_free} of {mem_total} bytes of free VRAM") + else: + results.append("loaded Torch but CUDA was not available") + except ImportError as e: + results.append(f"unable to import Torch CUDA: {e}") + except Exception as e: + results.append(f"error listing Torch CUDA: {e}") + + return results + + ALL_CHECKS = [ check_modules, check_runtimes, check_providers, + check_torch_cuda, ]