diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py index 2c65bdcd..3c6d3ffe 100644 --- a/api/onnx_web/device_pool.py +++ b/api/onnx_web/device_pool.py @@ -3,6 +3,10 @@ from logging import getLogger from multiprocessing import Value from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from .params import ( + DeviceParams, +) + logger = getLogger(__name__) @@ -10,7 +14,7 @@ class JobContext: def __init__( self, key: str, - devices: List[str], + devices: List[DeviceParams], cancel: bool = False, device_index: int = -1, progress: int = 0, @@ -24,7 +28,7 @@ class JobContext: def is_cancelled(self) -> bool: return self.cancel.value - def get_device(self) -> str: + def get_device(self) -> DeviceParams: ''' Get the device assigned to this job. ''' @@ -45,7 +49,8 @@ class JobContext: if self.is_cancelled(): raise Exception('job has been cancelled') else: - logger.debug('setting progress for job %s to %s', self.key, step) + logger.debug('setting progress for job %s to %s', + self.key, step) self.set_progress(step) return on_progress @@ -63,6 +68,7 @@ class Job: ''' Link a future to its context. 
''' + def __init__( self, key: str, @@ -88,16 +94,18 @@ class DevicePoolExecutor: jobs: List[Job] = None pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None - def __init__(self, devices: List[str], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None): + def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None): self.devices = devices self.jobs = [] device_count = len(devices) if pool is None: - logger.info('creating thread pool executor for %s devices: %s', device_count, devices) + logger.info( + 'creating thread pool executor for %s devices: %s', device_count, devices) self.pool = ThreadPoolExecutor(device_count) else: - logger.info('using existing pool for %s devices: %s', device_count, devices) + logger.info('using existing pool for %s devices: %s', + device_count, devices) self.pool = pool def cancel(self, key: str) -> bool: @@ -142,4 +150,4 @@ class DevicePoolExecutor: future.add_done_callback(job_done) def status(self) -> Dict[str, Tuple[bool, int]]: - return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs] \ No newline at end of file + return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs] diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index ceee0c94..4bf3cccc 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -50,7 +50,7 @@ def run_txt2img_pipeline( ) -> None: device = job.get_device() pipe = load_pipeline(OnnxStableDiffusionPipeline, - params.model, params.provider, params.scheduler, device=device) + params.model, device.provider, params.scheduler, device=device.torch_device()) latents = get_latents_from_seed(params.seed, size) rng = np.random.RandomState(params.seed) @@ -92,7 +92,7 @@ def run_img2img_pipeline( ) -> None: device = job.get_device() pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, - params.model, params.provider, params.scheduler, 
device=device) + params.model, device.provider, params.scheduler, device=device.torch_device()) rng = np.random.RandomState(params.seed) @@ -137,8 +137,8 @@ def run_inpaint_pipeline( strength: float, fill_color: str, ) -> None: - device = job.get_device() - progress = job.get_progress_callback() + # device = job.get_device() + # progress = job.get_progress_callback() stage = StageParams() # TODO: pass device, progress @@ -182,8 +182,8 @@ def run_upscale_pipeline( upscale: UpscaleParams, source_image: Image.Image, ) -> None: - device = job.get_device() - progress = job.get_progress_callback() + # device = job.get_device() + # progress = job.get_progress_callback() stage = StageParams() # TODO: pass device, progress diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index a087de42..04ae6413 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -59,6 +59,19 @@ class Size: } +class DeviceParams: + def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None: + self.device = device + self.provider = provider + self.options = options + + def torch_device(self) -> str: + if self.device.startswith('cuda'): + return self.device + else: + return 'cpu' + + class ImageParams: def __init__( self, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 157fc6a6..79adce4e 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -15,7 +15,6 @@ from diffusers import ( ) from flask import Flask, jsonify, make_response, request, send_from_directory, url_for from flask_cors import CORS -from flask_executor import Executor from glob import glob from io import BytesIO from jsonschema import validate @@ -23,7 +22,7 @@ from logging import getLogger from PIL import Image from onnxruntime import get_available_providers from os import makedirs, path -from typing import Tuple +from typing import List, Tuple from .chain import ( @@ -69,6 +68,7 @@ from .output import ( ) from .params import ( Border, + DeviceParams, 
ImageParams, Size, StageParams, @@ -88,6 +88,7 @@ from .utils import ( import gc import numpy as np +import torch import yaml logger = getLogger(__name__) @@ -147,7 +148,7 @@ chain_stages = { } # Available ORT providers -available_platforms = [] +available_platforms: List[DeviceParams] = [] # loaded from model_path diffusion_models = [] @@ -310,8 +311,16 @@ def load_platforms(): global available_platforms providers = get_available_providers() - available_platforms = [p for p in platform_providers if ( - platform_providers[p] in providers and p not in context.block_platforms)] + + for potential in platform_providers: + if platform_providers[potential] in providers and potential not in context.block_platforms: + if potential == 'cuda': + for i in range(torch.cuda.device_count()): + available_platforms.append(DeviceParams('%s:%s' % (potential, i), platform_providers[potential], { + 'device_id': i, + })) + else: + available_platforms.append(DeviceParams(potential, platform_providers[potential])) logger.info('available acceleration platforms: %s', available_platforms) @@ -404,7 +413,7 @@ def list_params(): @app.route('/api/settings/platforms') def list_platforms(): - return jsonify(list(available_platforms)) + return jsonify([p.device for p in available_platforms]) @app.route('/api/settings/schedulers') diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index 64c0f889..6447e815 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -48,6 +48,7 @@ "randn", "realesr", "resrgan", + "rocm", "RRDB", "runwayml", "scandir",