1
0
Fork 0

fix(api): filter platforms based on available providers (fixes #69)

This commit is contained in:
Sean Sube 2023-01-21 19:40:10 -06:00
parent 3a5bae6d0d
commit c768cd8f42
1 changed files with 15 additions and 1 deletions

View File

@ -19,6 +19,7 @@ from flask_executor import Executor
from glob import glob from glob import glob
from io import BytesIO from io import BytesIO
from PIL import Image from PIL import Image
from onnxruntime import get_available_providers
from os import makedirs, path, scandir from os import makedirs, path, scandir
from typing import Tuple from typing import Tuple
@ -72,7 +73,7 @@ platform_providers = {
'cuda': 'CUDAExecutionProvider', 'cuda': 'CUDAExecutionProvider',
'directml': 'DmlExecutionProvider', 'directml': 'DmlExecutionProvider',
'nvidia': 'CUDAExecutionProvider', 'nvidia': 'CUDAExecutionProvider',
'rocm': 'ROCmExecutionProvider', 'rocm': 'ROCMExecutionProvider',
} }
pipeline_schedulers = { pipeline_schedulers = {
'ddim': DDIMScheduler, 'ddim': DDIMScheduler,
@ -102,6 +103,9 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen, 'gaussian-screen': mask_filter_gaussian_screen,
} }
# Available ORT providers
available_platforms = []
# loaded from model_path # loaded from model_path
diffusion_models = [] diffusion_models = []
correction_models = [] correction_models = []
@ -248,11 +252,21 @@ def load_params(context: ServerContext):
config_params = json.load(f) config_params = json.load(f)
def load_platforms():
global available_platforms
providers = get_available_providers()
available_platforms = [p for p in platform_providers if (platform_providers[p] in providers)]
print('available acceleration platforms: %s' % (available_platforms))
context = ServerContext.from_environ() context = ServerContext.from_environ()
check_paths(context) check_paths(context)
load_models(context) load_models(context)
load_params(context) load_params(context)
load_platforms()
app = Flask(__name__) app = Flask(__name__)
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers