feat(api): add provider for each available CUDA device (#38)
parent f6dbab3422
commit 98b6e4dd03

@@ -3,6 +3,10 @@ from logging import getLogger
 from multiprocessing import Value
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+from .params import (
+    DeviceParams,
+)
+
 logger = getLogger(__name__)
 
 

@@ -10,7 +14,7 @@ class JobContext:
     def __init__(
         self,
         key: str,
-        devices: List[str],
+        devices: List[DeviceParams],
         cancel: bool = False,
         device_index: int = -1,
         progress: int = 0,

@@ -24,7 +28,7 @@ class JobContext:
     def is_cancelled(self) -> bool:
         return self.cancel.value
 
-    def get_device(self) -> str:
+    def get_device(self) -> DeviceParams:
         '''
         Get the device assigned to this job.
         '''

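Since get_device() now hands back a DeviceParams object instead of a plain device string, a job carries both the ONNX Runtime provider name and a torch-style device name for its assigned slot. A minimal sketch of that call pattern, using a hypothetical stand-in for JobContext and a local copy of the DeviceParams class added in params.py later in this diff (the index-based selection shown here is an assumption for illustration, not a quote of the real class):

from typing import List, Optional


class DeviceParams:
    # local copy of the class added in params.py below
    def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
        self.device = device
        self.provider = provider
        self.options = options

    def torch_device(self) -> str:
        return self.device if self.device.startswith('cuda') else 'cpu'


class FakeJobContext:
    # hypothetical stand-in; the real JobContext also tracks cancellation and progress
    def __init__(self, devices: List[DeviceParams], device_index: int = 0) -> None:
        self.devices = devices
        self.device_index = device_index

    def get_device(self) -> DeviceParams:
        return self.devices[self.device_index]


job = FakeJobContext([DeviceParams('cuda:1', 'CUDAExecutionProvider', {'device_id': 1})])
device = job.get_device()
print(device.provider, device.torch_device())  # CUDAExecutionProvider cuda:1
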
@@ -45,7 +49,8 @@ class JobContext:
             if self.is_cancelled():
                 raise Exception('job has been cancelled')
             else:
-                logger.debug('setting progress for job %s to %s', self.key, step)
+                logger.debug('setting progress for job %s to %s',
+                             self.key, step)
                 self.set_progress(step)
 
         return on_progress

@@ -63,6 +68,7 @@ class Job:
     '''
     Link a future to its context.
     '''
+
     def __init__(
         self,
         key: str,

@ -88,16 +94,18 @@ class DevicePoolExecutor:
|
||||||
jobs: List[Job] = None
|
jobs: List[Job] = None
|
||||||
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
|
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
|
||||||
|
|
||||||
def __init__(self, devices: List[str], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
|
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
self.jobs = []
|
self.jobs = []
|
||||||
|
|
||||||
device_count = len(devices)
|
device_count = len(devices)
|
||||||
if pool is None:
|
if pool is None:
|
||||||
logger.info('creating thread pool executor for %s devices: %s', device_count, devices)
|
logger.info(
|
||||||
|
'creating thread pool executor for %s devices: %s', device_count, devices)
|
||||||
self.pool = ThreadPoolExecutor(device_count)
|
self.pool = ThreadPoolExecutor(device_count)
|
||||||
else:
|
else:
|
||||||
logger.info('using existing pool for %s devices: %s', device_count, devices)
|
logger.info('using existing pool for %s devices: %s',
|
||||||
|
device_count, devices)
|
||||||
self.pool = pool
|
self.pool = pool
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
def cancel(self, key: str) -> bool:
|
||||||
|
|
|
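The executor now sizes its thread pool to the number of DeviceParams entries it is given, one worker per device. A rough illustration of that sizing choice (the device names here are hypothetical; the real list comes from load_platforms() in serve.py and holds DeviceParams objects, not strings):

from concurrent.futures import ThreadPoolExecutor

# hypothetical device names; each entry would be a DeviceParams in the real pool
devices = ['cpu', 'cuda:0', 'cuda:1']

# one worker per device, matching DevicePoolExecutor.__init__ above
pool = ThreadPoolExecutor(len(devices))
futures = [pool.submit(print, 'job running on', d) for d in devices]
pool.shutdown(wait=True)
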
@@ -50,7 +50,7 @@ def run_txt2img_pipeline(
 ) -> None:
     device = job.get_device()
     pipe = load_pipeline(OnnxStableDiffusionPipeline,
-                         params.model, params.provider, params.scheduler, device=device)
+                         params.model, device.provider, params.scheduler, device=device.torch_device())
 
     latents = get_latents_from_seed(params.seed, size)
     rng = np.random.RandomState(params.seed)

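The pipeline call now takes its execution provider from the job's device rather than from ImageParams, and derives the torch device string from the same object. load_pipeline itself is not shown in this diff; if it ultimately builds an ONNX Runtime session, the per-device options attached in load_platforms would typically be paired with the provider name, roughly like the sketch below (the model path and session wiring are illustrative assumptions, not the project's actual loader):

import onnxruntime as ort


def make_session(model_path: str, device) -> ort.InferenceSession:
    # ONNX Runtime accepts (provider_name, options) pairs, e.g.
    # ('CUDAExecutionProvider', {'device_id': 1}) to pin the second GPU
    providers = [(device.provider, device.options or {})]
    return ort.InferenceSession(model_path, providers=providers)


# session = make_session('/path/to/model.onnx', job.get_device())  # placeholder path
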
@@ -92,7 +92,7 @@ def run_img2img_pipeline(
 ) -> None:
     device = job.get_device()
     pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
-                         params.model, params.provider, params.scheduler, device=device)
+                         params.model, device.provider, params.scheduler, device=device.torch_device())
 
     rng = np.random.RandomState(params.seed)
 

@@ -137,8 +137,8 @@ def run_inpaint_pipeline(
     strength: float,
     fill_color: str,
 ) -> None:
-    device = job.get_device()
-    progress = job.get_progress_callback()
+    # device = job.get_device()
+    # progress = job.get_progress_callback()
     stage = StageParams()
 
     # TODO: pass device, progress

@@ -182,8 +182,8 @@ def run_upscale_pipeline(
     upscale: UpscaleParams,
     source_image: Image.Image,
 ) -> None:
-    device = job.get_device()
-    progress = job.get_progress_callback()
+    # device = job.get_device()
+    # progress = job.get_progress_callback()
     stage = StageParams()
 
     # TODO: pass device, progress

@@ -59,6 +59,19 @@ class Size:
         }
 
 
+class DeviceParams:
+    def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
+        self.device = device
+        self.provider = provider
+        self.options = options
+
+    def torch_device(self) -> str:
+        if self.device.startswith('cuda'):
+            return self.device
+        else:
+            return 'cpu'
+
+
 class ImageParams:
     def __init__(
         self,

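A quick usage note for the new class: torch_device() passes CUDA device names through unchanged and collapses everything else (CPU, DirectML, ROCm) to 'cpu' for the torch-side calls. A short example, assuming the class above is importable from onnx_web.params (the module path is inferred from the relative import style in this diff):

from onnx_web.params import DeviceParams  # assumed module path

gpu = DeviceParams('cuda:1', 'CUDAExecutionProvider', {'device_id': 1})
cpu = DeviceParams('cpu', 'CPUExecutionProvider')

print(gpu.torch_device())  # 'cuda:1'
print(cpu.torch_device())  # 'cpu'
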
@@ -15,7 +15,6 @@ from diffusers import (
 )
 from flask import Flask, jsonify, make_response, request, send_from_directory, url_for
 from flask_cors import CORS
-from flask_executor import Executor
 from glob import glob
 from io import BytesIO
 from jsonschema import validate

@@ -23,7 +22,7 @@ from logging import getLogger
 from PIL import Image
 from onnxruntime import get_available_providers
 from os import makedirs, path
-from typing import Tuple
+from typing import List, Tuple
 
 
 from .chain import (

@@ -69,6 +68,7 @@ from .output import (
 )
 from .params import (
     Border,
+    DeviceParams,
     ImageParams,
     Size,
     StageParams,

@@ -88,6 +88,7 @@ from .utils import (
 
 import gc
 import numpy as np
+import torch
 import yaml
 
 logger = getLogger(__name__)

@@ -147,7 +148,7 @@ chain_stages = {
 }
 
 # Available ORT providers
-available_platforms = []
+available_platforms: List[DeviceParams] = []
 
 # loaded from model_path
 diffusion_models = []

@@ -310,8 +311,16 @@ def load_platforms():
     global available_platforms
 
     providers = get_available_providers()
-    available_platforms = [p for p in platform_providers if (
-        platform_providers[p] in providers and p not in context.block_platforms)]
+    for potential in platform_providers:
+        if platform_providers[potential] in providers and potential not in context.block_platforms:
+            if potential == 'cuda':
+                for i in range(torch.cuda.device_count()):
+                    available_platforms.append(DeviceParams('%s:%s' % (potential, i), providers[potential], {
+                        'device_id': i,
+                    }))
+            else:
+                available_platforms.append(DeviceParams(potential, providers[potential]))
 
     logger.info('available acceleration platforms: %s', available_platforms)
 

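The rewritten loop registers one platform per physical CUDA device instead of a single 'cuda' entry, attaching ONNX Runtime's device_id option so each provider instance is pinned to its own GPU. A self-contained sketch of the same enumeration, with a stand-in name-to-provider mapping and without the context.block_platforms filter (the mapping values and the local DeviceParams copy are assumptions for illustration):

from typing import List, Optional

import torch
from onnxruntime import get_available_providers


# stand-in for serve.py's platform_providers mapping
platform_providers = {
    'cpu': 'CPUExecutionProvider',
    'cuda': 'CUDAExecutionProvider',
}


class DeviceParams:
    def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None:
        self.device = device
        self.provider = provider
        self.options = options


def enumerate_platforms() -> List[DeviceParams]:
    platforms: List[DeviceParams] = []
    available = get_available_providers()
    for name, provider in platform_providers.items():
        if provider not in available:
            continue
        if name == 'cuda':
            # one entry per GPU, pinned via the device_id provider option
            for i in range(torch.cuda.device_count()):
                platforms.append(DeviceParams('%s:%s' % (name, i), provider, {'device_id': i}))
        else:
            platforms.append(DeviceParams(name, provider))
    return platforms


print([p.device for p in enumerate_platforms()])  # e.g. ['cpu', 'cuda:0', 'cuda:1']
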
@@ -404,7 +413,7 @@ def list_params():
 
 @app.route('/api/settings/platforms')
 def list_platforms():
-    return jsonify(list(available_platforms))
+    return jsonify([p.device for p in available_platforms])
 
 
 @app.route('/api/settings/schedulers')

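With platforms stored as DeviceParams objects, the route serializes only the device names, so the client-facing response keeps its previous shape of plain strings. On a machine with two CUDA GPUs the payload would look roughly like this (values are illustrative):

import json

# hypothetical available_platforms after load_platforms() with two CUDA GPUs present
device_names = ['cpu', 'cuda:0', 'cuda:1']
print(json.dumps(device_names))  # what /api/settings/platforms now returns
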
@@ -48,6 +48,7 @@
     "randn",
     "realesr",
     "resrgan",
+    "rocm",
     "RRDB",
     "runwayml",
     "scandir",