1
0
Fork 0

feat(api): add provider for each available CUDA device (#38)

This commit is contained in:
Sean Sube 2023-02-04 13:49:34 -06:00
parent f6dbab3422
commit 98b6e4dd03
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
5 changed files with 50 additions and 19 deletions

View File

@ -3,6 +3,10 @@ from logging import getLogger
from multiprocessing import Value
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .params import (
DeviceParams,
)
logger = getLogger(__name__)
@ -10,7 +14,7 @@ class JobContext:
def __init__(
self,
key: str,
devices: List[str],
devices: List[DeviceParams],
cancel: bool = False,
device_index: int = -1,
progress: int = 0,
@ -24,7 +28,7 @@ class JobContext:
def is_cancelled(self) -> bool:
return self.cancel.value
def get_device(self) -> str:
def get_device(self) -> DeviceParams:
'''
Get the device assigned to this job.
'''
@ -45,7 +49,8 @@ class JobContext:
if self.is_cancelled():
raise Exception('job has been cancelled')
else:
logger.debug('setting progress for job %s to %s', self.key, step)
logger.debug('setting progress for job %s to %s',
self.key, step)
self.set_progress(step)
return on_progress
@ -63,6 +68,7 @@ class Job:
'''
Link a future to its context.
'''
def __init__(
self,
key: str,
@ -88,16 +94,18 @@ class DevicePoolExecutor:
jobs: List[Job] = None
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
def __init__(self, devices: List[str], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
self.devices = devices
self.jobs = []
device_count = len(devices)
if pool is None:
logger.info('creating thread pool executor for %s devices: %s', device_count, devices)
logger.info(
'creating thread pool executor for %s devices: %s', device_count, devices)
self.pool = ThreadPoolExecutor(device_count)
else:
logger.info('using existing pool for %s devices: %s', device_count, devices)
logger.info('using existing pool for %s devices: %s',
device_count, devices)
self.pool = pool
def cancel(self, key: str) -> bool:

View File

@ -50,7 +50,7 @@ def run_txt2img_pipeline(
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler, device=device)
params.model, device.provider, params.scheduler, device=device.torch_device())
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
@ -92,7 +92,7 @@ def run_img2img_pipeline(
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler, device=device)
params.model, device.provider, params.scheduler, device=device.torch_device())
rng = np.random.RandomState(params.seed)
@ -137,8 +137,8 @@ def run_inpaint_pipeline(
strength: float,
fill_color: str,
) -> None:
device = job.get_device()
progress = job.get_progress_callback()
# device = job.get_device()
# progress = job.get_progress_callback()
stage = StageParams()
# TODO: pass device, progress
@ -182,8 +182,8 @@ def run_upscale_pipeline(
upscale: UpscaleParams,
source_image: Image.Image,
) -> None:
device = job.get_device()
progress = job.get_progress_callback()
# device = job.get_device()
# progress = job.get_progress_callback()
stage = StageParams()
# TODO: pass device, progress

View File

@ -59,6 +59,19 @@ class Size:
}
class DeviceParams:
    """Pairing of a compute device name with its ONNX runtime provider.

    Attributes:
        device: device identifier string (e.g. 'cpu', 'cuda:0').
        provider: ONNX runtime execution provider name for this device.
        options: optional provider-specific session options (e.g. device_id).
    """

    def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
        self.device = device
        self.provider = provider
        self.options = options

    def torch_device(self) -> str:
        """Return the torch device string for this device.

        CUDA device names pass through unchanged; every other device
        (including other accelerators) is mapped to 'cpu'.
        """
        is_cuda = self.device.startswith('cuda')
        return self.device if is_cuda else 'cpu'
class ImageParams:
def __init__(
self,

View File

@ -15,7 +15,6 @@ from diffusers import (
)
from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
from flask_cors import CORS
from flask_executor import Executor
from glob import glob
from io import BytesIO
from jsonschema import validate
@ -23,7 +22,7 @@ from logging import getLogger
from PIL import Image
from onnxruntime import get_available_providers
from os import makedirs, path
from typing import Tuple
from typing import List, Tuple
from .chain import (
@ -69,6 +68,7 @@ from .output import (
)
from .params import (
Border,
DeviceParams,
ImageParams,
Size,
StageParams,
@ -88,6 +88,7 @@ from .utils import (
import gc
import numpy as np
import torch
import yaml
logger = getLogger(__name__)
@ -147,7 +148,7 @@ chain_stages = {
}
# Available ORT providers
available_platforms = []
available_platforms: List[DeviceParams] = []
# loaded from model_path
diffusion_models = []
@ -310,8 +311,16 @@ def load_platforms():
global available_platforms
providers = get_available_providers()
available_platforms = [p for p in platform_providers if (
platform_providers[p] in providers and p not in context.block_platforms)]
for potential in platform_providers:
if platform_providers[potential] in providers and potential not in context.block_platforms:
if potential == 'cuda':
for i in range(torch.cuda.device_count()):
available_platforms.append(DeviceParams('%s:%s' % (potential, i), providers[potential], {
'device_id': i,
}))
else:
available_platforms.append(DeviceParams(potential, providers[potential]))
logger.info('available acceleration platforms: %s', available_platforms)
@ -404,7 +413,7 @@ def list_params():
@app.route('/api/settings/platforms')
def list_platforms():
return jsonify(list(available_platforms))
return jsonify([p.device for p in available_platforms])
@app.route('/api/settings/schedulers')

View File

@ -48,6 +48,7 @@
"randn",
"realesr",
"resrgan",
"rocm",
"RRDB",
"runwayml",
"scandir",