From 391a707c846109fd05ab03f302400e99bd8931d7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Mar 2023 13:54:12 -0500 Subject: [PATCH] fix(scripts): check Torch CUDA devices (#242) --- api/scripts/check-env.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/api/scripts/check-env.py b/api/scripts/check-env.py index cf7b8b16..aa8e228f 100644 --- a/api/scripts/check-env.py +++ b/api/scripts/check-env.py @@ -59,10 +59,31 @@ def check_providers() -> List[str]: return results +def check_torch_cuda() -> List[str]: + results = [] + try: + import torch.cuda + + if torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + with torch.cuda.device(i): + mem_free, mem_total = torch.cuda.mem_get_info() + results.append(f"Torch has CUDA device {i} with {mem_free} of {mem_total} bytes of free VRAM") + else: + results.append("loaded Torch but CUDA was not available") + except ImportError as e: + results.append(f"unable to import Torch CUDA: {e}") + except Exception as e: + results.append(f"error listing Torch CUDA: {e}") + + return results + + ALL_CHECKS = [ check_modules, check_runtimes, check_providers, + check_torch_cuda, ]