1
0
Fork 0

feat(api): add provider for each available CUDA device (#38)

This commit is contained in:
Sean Sube 2023-02-04 13:49:34 -06:00
parent f6dbab3422
commit 98b6e4dd03
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 50 additions and 19 deletions

View File

@@ -3,6 +3,10 @@ from logging import getLogger
from multiprocessing import Value from multiprocessing import Value
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .params import (
DeviceParams,
)
logger = getLogger(__name__) logger = getLogger(__name__)
@@ -10,7 +14,7 @@ class JobContext:
def __init__( def __init__(
self, self,
key: str, key: str,
devices: List[str], devices: List[DeviceParams],
cancel: bool = False, cancel: bool = False,
device_index: int = -1, device_index: int = -1,
progress: int = 0, progress: int = 0,
@@ -24,7 +28,7 @@ class JobContext:
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
return self.cancel.value return self.cancel.value
def get_device(self) -> str: def get_device(self) -> DeviceParams:
''' '''
Get the device assigned to this job. Get the device assigned to this job.
''' '''
@@ -45,7 +49,8 @@ class JobContext:
if self.is_cancelled(): if self.is_cancelled():
raise Exception('job has been cancelled') raise Exception('job has been cancelled')
else: else:
logger.debug('setting progress for job %s to %s', self.key, step) logger.debug('setting progress for job %s to %s',
self.key, step)
self.set_progress(step) self.set_progress(step)
return on_progress return on_progress
@@ -63,6 +68,7 @@ class Job:
''' '''
Link a future to its context. Link a future to its context.
''' '''
def __init__( def __init__(
self, self,
key: str, key: str,
@@ -88,16 +94,18 @@ class DevicePoolExecutor:
jobs: List[Job] = None jobs: List[Job] = None
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
def __init__(self, devices: List[str], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None): def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
self.devices = devices self.devices = devices
self.jobs = [] self.jobs = []
device_count = len(devices) device_count = len(devices)
if pool is None: if pool is None:
logger.info('creating thread pool executor for %s devices: %s', device_count, devices) logger.info(
'creating thread pool executor for %s devices: %s', device_count, devices)
self.pool = ThreadPoolExecutor(device_count) self.pool = ThreadPoolExecutor(device_count)
else: else:
logger.info('using existing pool for %s devices: %s', device_count, devices) logger.info('using existing pool for %s devices: %s',
device_count, devices)
self.pool = pool self.pool = pool
def cancel(self, key: str) -> bool: def cancel(self, key: str) -> bool:
@@ -142,4 +150,4 @@ class DevicePoolExecutor:
future.add_done_callback(job_done) future.add_done_callback(job_done)
def status(self) -> Dict[str, Tuple[bool, int]]: def status(self) -> Dict[str, Tuple[bool, int]]:
return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs] return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs]

View File

@@ -50,7 +50,7 @@ def run_txt2img_pipeline(
) -> None: ) -> None:
device = job.get_device() device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionPipeline, pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler, device=device) params.model, device.provider, params.scheduler, device=device.torch_device())
latents = get_latents_from_seed(params.seed, size) latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@@ -92,7 +92,7 @@ def run_img2img_pipeline(
) -> None: ) -> None:
device = job.get_device() device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler, device=device) params.model, device.provider, params.scheduler, device=device.torch_device())
rng = np.random.RandomState(params.seed) rng = np.random.RandomState(params.seed)
@@ -137,8 +137,8 @@ def run_inpaint_pipeline(
strength: float, strength: float,
fill_color: str, fill_color: str,
) -> None: ) -> None:
device = job.get_device() # device = job.get_device()
progress = job.get_progress_callback() # progress = job.get_progress_callback()
stage = StageParams() stage = StageParams()
# TODO: pass device, progress # TODO: pass device, progress
@@ -182,8 +182,8 @@ def run_upscale_pipeline(
upscale: UpscaleParams, upscale: UpscaleParams,
source_image: Image.Image, source_image: Image.Image,
) -> None: ) -> None:
device = job.get_device() # device = job.get_device()
progress = job.get_progress_callback() # progress = job.get_progress_callback()
stage = StageParams() stage = StageParams()
# TODO: pass device, progress # TODO: pass device, progress

View File

@@ -59,6 +59,19 @@ class Size:
} }
class DeviceParams:
    '''
    A compute device available to the job pool: its name (e.g. 'cpu' or
    'cuda:0'), the ONNX runtime execution provider to use on it, and any
    provider-specific session options (e.g. {'device_id': 0}).
    '''

    def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
        self.device = device
        self.provider = provider
        self.options = options

    def __repr__(self) -> str:
        # lists of devices are logged directly (logger.info('... platforms: %s', ...)),
        # so make the repr human-readable instead of the default object address
        return '%s - %s' % (self.device, self.provider)

    def torch_device(self) -> str:
        '''
        Return the torch device string for this device: CUDA devices map to
        themselves, everything else falls back to 'cpu'.
        '''
        if self.device.startswith('cuda'):
            return self.device
        else:
            return 'cpu'
class ImageParams: class ImageParams:
def __init__( def __init__(
self, self,

View File

@@ -15,7 +15,6 @@ from diffusers import (
) )
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS from flask_cors import CORS
from flask_executor import Executor
from glob import glob from glob import glob
from io import BytesIO from io import BytesIO
from jsonschema import validate from jsonschema import validate
@@ -23,7 +22,7 @@ from logging import getLogger
from PIL import Image from PIL import Image
from onnxruntime import get_available_providers from onnxruntime import get_available_providers
from os import makedirs, path from os import makedirs, path
from typing import Tuple from typing import List, Tuple
from .chain import ( from .chain import (
@@ -69,6 +68,7 @@ from .output import (
) )
from .params import ( from .params import (
Border, Border,
DeviceParams,
ImageParams, ImageParams,
Size, Size,
StageParams, StageParams,
@@ -88,6 +88,7 @@ from .utils import (
import gc import gc
import numpy as np import numpy as np
import torch
import yaml import yaml
logger = getLogger(__name__) logger = getLogger(__name__)
@@ -147,7 +148,7 @@ chain_stages = {
} }
# Available ORT providers # Available ORT providers
available_platforms = [] available_platforms: List[DeviceParams] = []
# loaded from model_path # loaded from model_path
diffusion_models = [] diffusion_models = []
@@ -310,8 +311,16 @@ def load_platforms():
global available_platforms global available_platforms
providers = get_available_providers() providers = get_available_providers()
available_platforms = [p for p in platform_providers if (
platform_providers[p] in providers and p not in context.block_platforms)] for potential in platform_providers:
if platform_providers[potential] in providers and potential not in context.block_platforms:
if potential == 'cuda':
for i in range(torch.cuda.device_count()):
available_platforms.append(DeviceParams('%s:%s' % (potential, i), providers[potential], {
'device_id': i,
}))
else:
available_platforms.append(DeviceParams(potential, providers[potential]))
logger.info('available acceleration platforms: %s', available_platforms) logger.info('available acceleration platforms: %s', available_platforms)
@@ -404,7 +413,7 @@ def list_params():
@app.route('/api/settings/platforms') @app.route('/api/settings/platforms')
def list_platforms(): def list_platforms():
return jsonify(list(available_platforms)) return jsonify([p.device for p in available_platforms])
@app.route('/api/settings/schedulers') @app.route('/api/settings/schedulers')

View File

@@ -48,6 +48,7 @@
"randn", "randn",
"realesr", "realesr",
"resrgan", "resrgan",
"rocm",
"RRDB", "RRDB",
"runwayml", "runwayml",
"scandir", "scandir",