fix(scripts): check for ORT modules
This commit is contained in:
parent
84718e5928
commit
c465b61fb5
|
@ -3,7 +3,7 @@ from typing import List
|
||||||
|
|
||||||
REQUIRED_MODULES = ["onnx", "diffusers", "safetensors", "torch"]
|
REQUIRED_MODULES = ["onnx", "diffusers", "safetensors", "torch"]
|
||||||
|
|
||||||
REQUIRED_RUNTIME = [
|
RUNTIME_MODULES = [
|
||||||
"onnxruntime",
|
"onnxruntime",
|
||||||
"onnxruntime_gpu",
|
"onnxruntime_gpu",
|
||||||
"onnxruntime_rocm",
|
"onnxruntime_rocm",
|
||||||
|
@ -26,6 +26,21 @@ def check_modules() -> List[str]:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def check_runtimes() -> List[str]:
|
||||||
|
results = []
|
||||||
|
for name in RUNTIME_MODULES:
|
||||||
|
try:
|
||||||
|
__import__(name)
|
||||||
|
module_version = version(name)
|
||||||
|
results.append(
|
||||||
|
f"runtime module {name} is present at version {module_version}"
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
results.append(f"unable to import runtime module {name}: {e}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def check_providers() -> List[str]:
|
def check_providers() -> List[str]:
|
||||||
results = []
|
results = []
|
||||||
try:
|
try:
|
||||||
|
@ -46,6 +61,7 @@ def check_providers() -> List[str]:
|
||||||
|
|
||||||
ALL_CHECKS = [
|
ALL_CHECKS = [
|
||||||
check_modules,
|
check_modules,
|
||||||
|
check_runtimes,
|
||||||
check_providers,
|
check_providers,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue