From c768cd8f42d31787160cff5751b90f7cbbe2d63d Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 21 Jan 2023 19:40:10 -0600 Subject: [PATCH] fix(api): filter platforms based on available providers (fixes #69) --- api/onnx_web/serve.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 0c4b75a8..0c8de54d 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -19,6 +19,7 @@ from flask_executor import Executor from glob import glob from io import BytesIO from PIL import Image +from onnxruntime import get_available_providers from os import makedirs, path, scandir from typing import Tuple @@ -72,7 +73,7 @@ platform_providers = { 'cuda': 'CUDAExecutionProvider', 'directml': 'DmlExecutionProvider', 'nvidia': 'CUDAExecutionProvider', - 'rocm': 'ROCmExecutionProvider', + 'rocm': 'ROCMExecutionProvider', } pipeline_schedulers = { 'ddim': DDIMScheduler, @@ -102,6 +103,9 @@ mask_filters = { 'gaussian-screen': mask_filter_gaussian_screen, } +# Available ORT providers +available_platforms = [] + # loaded from model_path diffusion_models = [] correction_models = [] @@ -248,11 +252,21 @@ def load_params(context: ServerContext): config_params = json.load(f) +def load_platforms(): + global available_platforms + + providers = get_available_providers() + available_platforms = [p for p in platform_providers if (platform_providers[p] in providers)] + + print('available acceleration platforms: %s' % (available_platforms)) + + context = ServerContext.from_environ() check_paths(context) load_models(context) load_params(context) +load_platforms() app = Flask(__name__) app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers