diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py
index 3568407d..bdc0449a 100644
--- a/api/onnx_web/device_pool.py
+++ b/api/onnx_web/device_pool.py
@@ -135,8 +135,20 @@ class DevicePoolExecutor:
         return (None, 0)
 
     def get_next_device(self):
+        # use the first/default device if there are no jobs
+        if len(self.jobs) == 0:
+            return 0
+
         job_devices = [job.context.device_index.value for job in self.jobs]
-        queued = Counter(job_devices).most_common()
+        # seed the counter with every known device so that idle devices
+        # are counted and can be selected
+        job_counts = Counter(range(len(self.devices)))
+        job_counts.update(job_devices)
+
+        queued = job_counts.most_common()
         logger.debug('jobs queued by device: %s', queued)
-        return queued[-1]
+
+        # most_common() yields (device, count) pairs; return only the device
+        # index so the result matches the no-jobs branch above
+        return queued[-1][0]
 
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index c418017f..088ed816 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -15,6 +15,7 @@ from diffusers import (
 )
 from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
 from flask_cors import CORS
+from functools import cmp_to_key
 from glob import glob
 from io import BytesIO
 from jsonschema import validate
@@ -333,6 +334,15 @@ def load_platforms():
             available_platforms.append(DeviceParams(
                 potential, platform_providers[potential]))
 
+    # make sure CPU is last on the list
+    def cpu_last(a: DeviceParams, b: DeviceParams):
+        # compare only on "is this the CPU": returning 0 for every other
+        # pair keeps the comparator consistent, so the stable sort leaves
+        # the non-CPU platforms in their detected order
+        return int(a.device == 'cpu') - int(b.device == 'cpu')
+
+    available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
+
     logger.info('available acceleration platforms: %s',
                 ', '.join([str(p) for p in available_platforms]))
 