fix(api): filter platforms based on available providers (fixes #69)
This commit is contained in:
parent
3a5bae6d0d
commit
c768cd8f42
|
@ -19,6 +19,7 @@ from flask_executor import Executor
|
|||
from glob import glob
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
from onnxruntime import get_available_providers
|
||||
from os import makedirs, path, scandir
|
||||
from typing import Tuple
|
||||
|
||||
|
@ -72,7 +73,7 @@ platform_providers = {
|
|||
'cuda': 'CUDAExecutionProvider',
|
||||
'directml': 'DmlExecutionProvider',
|
||||
'nvidia': 'CUDAExecutionProvider',
|
||||
'rocm': 'ROCmExecutionProvider',
|
||||
'rocm': 'ROCMExecutionProvider',
|
||||
}
|
||||
pipeline_schedulers = {
|
||||
'ddim': DDIMScheduler,
|
||||
|
@ -102,6 +103,9 @@ mask_filters = {
|
|||
'gaussian-screen': mask_filter_gaussian_screen,
|
||||
}
|
||||
|
||||
# Available ORT providers
|
||||
available_platforms = []
|
||||
|
||||
# loaded from model_path
|
||||
diffusion_models = []
|
||||
correction_models = []
|
||||
|
@ -248,11 +252,21 @@ def load_params(context: ServerContext):
|
|||
config_params = json.load(f)
|
||||
|
||||
|
||||
def load_platforms():
|
||||
global available_platforms
|
||||
|
||||
providers = get_available_providers()
|
||||
available_platforms = [p for p in platform_providers if (platform_providers[p] in providers)]
|
||||
|
||||
print('available acceleration platforms: %s' % (available_platforms))
|
||||
|
||||
|
||||
context = ServerContext.from_environ()
|
||||
|
||||
check_paths(context)
|
||||
load_models(context)
|
||||
load_params(context)
|
||||
load_platforms()
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
|
||||
|
|
Loading…
Reference in New Issue