fix(scripts): check for ORT modules
This commit is contained in:
parent
84718e5928
commit
c465b61fb5
|
@ -3,7 +3,7 @@ from typing import List
|
|||
|
||||
REQUIRED_MODULES = ["onnx", "diffusers", "safetensors", "torch"]
|
||||
|
||||
REQUIRED_RUNTIME = [
|
||||
RUNTIME_MODULES = [
|
||||
"onnxruntime",
|
||||
"onnxruntime_gpu",
|
||||
"onnxruntime_rocm",
|
||||
|
@ -26,6 +26,21 @@ def check_modules() -> List[str]:
|
|||
return results
|
||||
|
||||
|
||||
def check_runtimes() -> List[str]:
|
||||
results = []
|
||||
for name in RUNTIME_MODULES:
|
||||
try:
|
||||
__import__(name)
|
||||
module_version = version(name)
|
||||
results.append(
|
||||
f"runtime module {name} is present at version {module_version}"
|
||||
)
|
||||
except ImportError as e:
|
||||
results.append(f"unable to import runtime module {name}: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def check_providers() -> List[str]:
|
||||
results = []
|
||||
try:
|
||||
|
@ -46,6 +61,7 @@ def check_providers() -> List[str]:
|
|||
|
||||
ALL_CHECKS = [
|
||||
check_modules,
|
||||
check_runtimes,
|
||||
check_providers,
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue