1
0
Fork 0

fix(api): filter platforms based on available providers (fixes #69)

This commit is contained in:
Sean Sube 2023-01-21 19:40:10 -06:00
parent 3a5bae6d0d
commit c768cd8f42
1 changed files with 15 additions and 1 deletions

View File

@ -19,6 +19,7 @@ from flask_executor import Executor
from glob import glob
from io import BytesIO
from PIL import Image
from onnxruntime import get_available_providers
from os import makedirs, path, scandir
from typing import Tuple
@ -72,7 +73,7 @@ platform_providers = {
'cuda': 'CUDAExecutionProvider',
'directml': 'DmlExecutionProvider',
'nvidia': 'CUDAExecutionProvider',
'rocm': 'ROCmExecutionProvider',
'rocm': 'ROCMExecutionProvider',
}
pipeline_schedulers = {
'ddim': DDIMScheduler,
@ -102,6 +103,9 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen,
}
# Available ORT providers
available_platforms = []
# loaded from model_path
diffusion_models = []
correction_models = []
@ -248,11 +252,21 @@ def load_params(context: ServerContext):
config_params = json.load(f)
def load_platforms():
global available_platforms
providers = get_available_providers()
available_platforms = [p for p in platform_providers if (platform_providers[p] in providers)]
print('available acceleration platforms: %s' % (available_platforms))
context = ServerContext.from_environ()
check_paths(context)
load_models(context)
load_params(context)
load_platforms()
app = Flask(__name__)
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers