1
0
Fork 0

fix(api): set CUDA device in ORT session

This commit is contained in:
Sean Sube 2023-02-04 21:17:39 -06:00
parent d636ce3eef
commit 04a2faffd9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
21 changed files with 116 additions and 52 deletions

6
api/dev-requirements.txt Normal file
View File

@@ -0,0 +1,6 @@
mypy
types-Flask-Cors
types-jsonschema
types-Pillow
types-PyYAML

View File

@@ -28,6 +28,7 @@ logger = getLogger(__name__)
class StageCallback(Protocol):
def __call__(
self,
job: JobContext,
ctx: ServerContext,
stage: StageParams,
params: ImageParams,
@@ -83,7 +84,7 @@ class ChainPipeline:
stage_params.tile_size)
def stage_tile(tile: Image.Image, _dims) -> Image.Image:
tile = stage_pipe(server, stage_params, params, tile,
tile = stage_pipe(job, server, stage_params, params, tile,
**kwargs)
if is_debug():
@@ -95,7 +96,7 @@ class ChainPipeline:
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else:
logger.info('image within tile size, running stage')
image = stage_pipe(server, stage_params, params, image,
image = stage_pipe(job, server, stage_params, params, image,
**kwargs)
logger.info('finished stage %s, result size: %sx%s',

View File

@@ -4,6 +4,9 @@ from diffusers import (
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
load_pipeline,
)
@@ -21,7 +24,8 @@ logger = getLogger(__name__)
def blend_img2img(
_ctx: ServerContext,
job: JobContext,
_server: ServerContext,
_stage: StageParams,
params: ImageParams,
source_image: Image.Image,
@@ -34,7 +38,7 @@ def blend_img2img(
logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, params.provider, params.scheduler)
params.model, params.scheduler, job.get_device())
rng = np.random.RandomState(params.seed)

View File

@@ -5,6 +5,9 @@ from logging import getLogger
from PIL import Image
from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
@@ -38,7 +41,8 @@ logger = getLogger(__name__)
def blend_inpaint(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
@@ -65,9 +69,9 @@ def blend_inpaint(
mask_filter=mask_filter)
if is_debug():
save_image(ctx, 'last-source.png', source_image)
save_image(ctx, 'last-mask.png', mask_image)
save_image(ctx, 'last-noise.png', noise_image)
save_image(server, 'last-source.png', source_image)
save_image(server, 'last-mask.png', mask_image)
save_image(server, 'last-noise.png', noise_image)
def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
left, top, tile = dims
@@ -75,11 +79,11 @@ def blend_inpaint(
mask = mask_image.crop((left, top, left + tile, top + tile))
if is_debug():
save_image(ctx, 'tile-source.png', image)
save_image(ctx, 'tile-mask.png', mask)
save_image(server, 'tile-source.png', image)
save_image(server, 'tile-mask.png', mask)
pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
params.model, params.provider, params.scheduler)
params.model, params.scheduler, job.get_device())
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)

View File

@@ -5,6 +5,9 @@ from PIL import Image
from realesrgan import RealESRGANer
from typing import Optional
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
@@ -60,7 +63,8 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[
def correct_gfpgan(
ctx: ServerContext,
_job: JobContext,
server: ServerContext,
_stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
@@ -74,7 +78,7 @@ def correct_gfpgan(
return source_image
logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model)
gfpgan = load_gfpgan(ctx, upscale, upsampler=upsampler)
gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)
output = np.array(source_image)
_, _, output = gfpgan.enhance(

View File

@@ -1,6 +1,9 @@
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
@@ -16,6 +19,7 @@ logger = getLogger(__name__)
def persist_disk(
_job: JobContext,
ctx: ServerContext,
_stage: StageParams,
_params: ImageParams,

View File

@@ -5,6 +5,9 @@ from io import BytesIO
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,

View File

@@ -1,6 +1,9 @@
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,

View File

@@ -1,6 +1,9 @@
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,

View File

@@ -2,6 +2,9 @@ from logging import getLogger
from PIL import Image
from typing import Callable
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
Size,

View File

@@ -4,6 +4,9 @@ from diffusers import (
from logging import getLogger
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
load_pipeline,
@@ -23,7 +26,8 @@ logger = getLogger(__name__)
def source_txt2img(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
params: ImageParams,
source_image: Image.Image,
@@ -39,7 +43,7 @@ def source_txt2img(
logger.warn('a source image was passed to a txt2img stage, but will be discarded')
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, params.provider, params.scheduler)
params.model, params.scheduler, job.get_device())
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)

View File

@@ -5,6 +5,9 @@ from logging import getLogger
from PIL import Image, ImageDraw
from typing import Callable, Tuple
from ..device_pool import (
JobContext,
)
from ..diffusion.load import (
get_latents_from_seed,
get_tile_latents,

View File

@@ -4,10 +4,14 @@ from os import path
from PIL import Image
from realesrgan import RealESRGANer
from ..device_pool import (
JobContext,
)
from ..onnx import (
OnnxNet,
)
from ..params import (
DeviceParams,
ImageParams,
StageParams,
UpscaleParams,
@@ -25,7 +29,7 @@ last_pipeline_instance = None
last_pipeline_params = (None, None)
def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
global last_pipeline_instance
global last_pipeline_params
@@ -41,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
# use ONNX acceleration, if available
if params.format == 'onnx':
model = OnnxNet(ctx, model_file, provider=params.provider)
model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
elif params.format == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
@@ -76,7 +80,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
def upscale_resrgan(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
@@ -87,7 +92,7 @@ def upscale_resrgan(
logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale)
output = np.array(source_image)
upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)
output, _ = upsampler.enhance(output, outscale=upscale.outscale)

View File

@@ -5,10 +5,14 @@ from logging import getLogger
from os import path
from PIL import Image
from ..device_pool import (
JobContext,
)
from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
OnnxStableDiffusionUpscalePipeline,
)
from ..params import (
DeviceParams,
ImageParams,
StageParams,
UpscaleParams,
@@ -27,7 +31,7 @@ last_pipeline_instance = None
last_pipeline_params = (None, None)
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams):
global last_pipeline_instance
global last_pipeline_params
@@ -39,11 +43,11 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
return last_pipeline_instance
if upscale.format == 'onnx':
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, upscale.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider)
logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
else:
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, upscale.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider)
logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
last_pipeline_instance = pipeline
last_pipeline_params = cache_params
@@ -53,7 +57,8 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
def upscale_stable_diffusion(
ctx: ServerContext,
job: JobContext,
server: ServerContext,
_stage: StageParams,
params: ImageParams,
source: Image.Image,
@@ -65,7 +70,7 @@ def upscale_stable_diffusion(
prompt = prompt or params.prompt
logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)
pipeline = load_stable_diffusion(ctx, upscale)
pipeline = load_stable_diffusion(server, upscale, job.get_device())
generator = torch.manual_seed(params.seed)
return pipeline(

View File

@@ -133,6 +133,9 @@ class DevicePoolExecutor:
return True
else:
job.set_cancel()
return True
return False
def done(self, key: str) -> Tuple[bool, int]:
for job in self.jobs:

View File

@@ -5,6 +5,7 @@ from logging import getLogger
from typing import Any, Optional, Tuple
from ..params import (
DeviceParams,
Size,
)
from ..utils import (
@@ -45,12 +46,12 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
return full_latents[:, :, y:yt, x:xt]
def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams):
global last_pipeline_instance
global last_pipeline_scheduler
global last_pipeline_options
options = (pipeline, model, provider)
options = (pipeline, model, device.provider)
if last_pipeline_instance != None and last_pipeline_options == options:
logger.debug('reusing existing diffusion pipeline')
pipe = last_pipeline_instance
@@ -61,11 +62,18 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
run_gc()
logger.debug('loading new diffusion pipeline from %s', model)
scheduler = scheduler.from_pretrained(
model,
provider=device.provider,
sess_options=device.options,
subfolder='scheduler',
)
pipe = pipeline.from_pretrained(
model,
provider=provider,
provider=device.provider,
safety_checker=None,
scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
scheduler=scheduler,
sess_options=device.options,
)
if device is not None and hasattr(pipe, 'to'):
@@ -78,7 +86,11 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
if last_pipeline_scheduler != scheduler:
logger.debug('loading new diffusion scheduler')
scheduler = scheduler.from_pretrained(
model, subfolder='scheduler')
model,
provider=device.provider,
sess_options=device.options,
subfolder='scheduler',
)
if device is not None and hasattr(scheduler, 'to'):
scheduler = scheduler.to(device)

View File

@@ -48,9 +48,8 @@ def run_txt2img_pipeline(
output: str,
upscale: UpscaleParams
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionPipeline,
params.model, device.provider, params.scheduler, device=device.torch_device())
params.model, params.scheduler, job.get_device())
latents = get_latents_from_seed(params.seed, size)
rng = np.random.RandomState(params.seed)
@@ -90,9 +89,8 @@ def run_img2img_pipeline(
source_image: Image.Image,
strength: float,
) -> None:
device = job.get_device()
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
params.model, device.provider, params.scheduler, device=device.torch_device())
params.model, params.scheduler, job.get_device())
rng = np.random.RandomState(params.seed)

View File

@@ -1,6 +1,6 @@
from onnxruntime import InferenceSession
from os import path
from typing import Any
from typing import Any, Optional
import numpy as np
import torch
@@ -43,13 +43,19 @@ class OnnxNet():
Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
'''
def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
def __init__(
self,
server: ServerContext,
model: str,
provider: str = 'DmlExecutionProvider',
sess_options: Optional[dict] = None,
) -> None:
'''
TODO: get platform provider from request params
'''
model_path = path.join(ctx.model_path, model)
model_path = path.join(server.model_path, model)
self.session = InferenceSession(
model_path, providers=[provider])
model_path, providers=[provider], sess_options=sess_options)
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name

View File

@@ -79,7 +79,6 @@ class ImageParams:
def __init__(
self,
model: str,
provider: str,
scheduler: Any,
prompt: str,
negative_prompt: Optional[str],
@@ -88,7 +87,6 @@
seed: int
) -> None:
self.model = model
self.provider = provider
self.scheduler = scheduler
self.prompt = prompt
self.negative_prompt = negative_prompt
@@ -96,10 +94,9 @@
self.steps = steps
self.seed = seed
def tojson(self) -> Dict[str, Param]:
def tojson(self) -> Dict[str, Optional[Param]]:
return {
'model': self.model,
'provider': self.provider,
'scheduler': self.scheduler.__name__,
'seed': self.seed,
'prompt': self.prompt,
@@ -130,7 +127,6 @@ class UpscaleParams():
def __init__(
self,
upscale_model: str,
provider: str,
correction_model: Optional[str] = None,
denoise: float = 0.5,
faces=True,
@@ -143,7 +139,6 @@
tile_pad: int = 10,
) -> None:
self.upscale_model = upscale_model
self.provider = provider
self.correction_model = correction_model
self.denoise = denoise
self.faces = faces
@@ -158,7 +153,6 @@
def rescale(self, scale: int):
return UpscaleParams(
self.upscale_model,
self.provider,
correction_model=self.correction_model,
denoise=self.denoise,
faces=self.faces,

View File

@@ -226,7 +226,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)
params = ImageParams(model_path, device.provider, scheduler, prompt,
params = ImageParams(model_path, scheduler, prompt,
negative_prompt, cfg, steps, seed)
size = Size(width, height)
return (device, params, size)
@@ -245,7 +245,7 @@ def border_from_request() -> Border:
return Border(left, right, top, bottom)
def upscale_from_request(provider: str) -> UpscaleParams:
def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
@@ -257,7 +257,6 @@ def upscale_from_request(provider: str) -> UpscaleParams:
return UpscaleParams(
upscaling,
provider,
correction_model=correction,
denoise=denoise,
faces=faces,
@@ -327,7 +326,7 @@ def load_platforms():
if platform_providers[potential] in providers and potential not in context.block_platforms:
if potential == 'cuda':
for i in range(torch.cuda.device_count()):
available_platforms.append(DeviceParams('%s:%s' % (potential, i), platform_providers[potential], {
available_platforms.append(DeviceParams(potential, platform_providers[potential], {
'device_id': i,
}))
else:
@@ -344,7 +343,6 @@ def load_platforms():
return -1
available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
logger.info('available acceleration platforms: %s',
@@ -456,7 +454,7 @@ def img2img():
source_image = Image.open(BytesIO(source_file.read())).convert('RGB')
device, params, size = pipeline_from_request()
upscale = upscale_from_request(params.provider)
upscale = upscale_from_request()
strength = get_and_clamp_float(
request.args,

View File

@@ -54,6 +54,7 @@
"scandir",
"scipy",
"scrollback",
"sess",
"Singlestep",
"spacy",
"spinalcase",