fix(api): filter platforms based on available providers (fixes #69)
This commit is contained in:
parent
3a5bae6d0d
commit
c768cd8f42
|
@ -19,6 +19,7 @@ from flask_executor import Executor
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from onnxruntime import get_available_providers
|
||||||
from os import makedirs, path, scandir
|
from os import makedirs, path, scandir
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
@ -72,7 +73,7 @@ platform_providers = {
|
||||||
'cuda': 'CUDAExecutionProvider',
|
'cuda': 'CUDAExecutionProvider',
|
||||||
'directml': 'DmlExecutionProvider',
|
'directml': 'DmlExecutionProvider',
|
||||||
'nvidia': 'CUDAExecutionProvider',
|
'nvidia': 'CUDAExecutionProvider',
|
||||||
'rocm': 'ROCmExecutionProvider',
|
'rocm': 'ROCMExecutionProvider',
|
||||||
}
|
}
|
||||||
pipeline_schedulers = {
|
pipeline_schedulers = {
|
||||||
'ddim': DDIMScheduler,
|
'ddim': DDIMScheduler,
|
||||||
|
@ -102,6 +103,9 @@ mask_filters = {
|
||||||
'gaussian-screen': mask_filter_gaussian_screen,
|
'gaussian-screen': mask_filter_gaussian_screen,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Available ORT providers
|
||||||
|
available_platforms = []
|
||||||
|
|
||||||
# loaded from model_path
|
# loaded from model_path
|
||||||
diffusion_models = []
|
diffusion_models = []
|
||||||
correction_models = []
|
correction_models = []
|
||||||
|
@ -248,11 +252,21 @@ def load_params(context: ServerContext):
|
||||||
config_params = json.load(f)
|
config_params = json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
def load_platforms():
|
||||||
|
global available_platforms
|
||||||
|
|
||||||
|
providers = get_available_providers()
|
||||||
|
available_platforms = [p for p in platform_providers if (platform_providers[p] in providers)]
|
||||||
|
|
||||||
|
print('available acceleration platforms: %s' % (available_platforms))
|
||||||
|
|
||||||
|
|
||||||
context = ServerContext.from_environ()
|
context = ServerContext.from_environ()
|
||||||
|
|
||||||
check_paths(context)
|
check_paths(context)
|
||||||
load_models(context)
|
load_models(context)
|
||||||
load_params(context)
|
load_params(context)
|
||||||
|
load_platforms()
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
|
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
|
||||||
|
|
Loading…
Reference in New Issue