fix(api): set CUDA device in ORT session
commit 04a2faffd9
parent d636ce3eef
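This change moves device selection off of ImageParams and UpscaleParams and onto a DeviceParams object carried by each worker's JobContext, so ONNX sessions are created with the provider and options of the device assigned to the job. For reference, ONNX Runtime pins a session to a specific CUDA device through the execution provider's options; a minimal sketch of that mechanism, with an illustrative model path and device index that are not from this repo:

from onnxruntime import InferenceSession

# Illustrative: run on the second CUDA GPU, falling back to CPU.
session = InferenceSession(
    'model.onnx',  # hypothetical model path
    providers=[
        ('CUDAExecutionProvider', {'device_id': 1}),
        'CPUExecutionProvider',
    ],
)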
@@ -0,0 +1,6 @@
+mypy
+types-Flask-Cors
+types-jsonschema
+types-Pillow
+types-PyYAML
@@ -28,6 +28,7 @@ logger = getLogger(__name__)
 class StageCallback(Protocol):
     def __call__(
         self,
+        job: JobContext,
         ctx: ServerContext,
         stage: StageParams,
         params: ImageParams,
@@ -83,7 +84,7 @@ class ChainPipeline:
                     stage_params.tile_size)

                 def stage_tile(tile: Image.Image, _dims) -> Image.Image:
-                    tile = stage_pipe(server, stage_params, params, tile,
+                    tile = stage_pipe(job, server, stage_params, params, tile,
                                       **kwargs)

                     if is_debug():
@@ -95,7 +96,7 @@ class ChainPipeline:
                     image, stage_params.tile_size, stage_params.outscale, [stage_tile])
             else:
                 logger.info('image within tile size, running stage')
-                image = stage_pipe(server, stage_params, params, image,
+                image = stage_pipe(job, server, stage_params, params, image,
                                    **kwargs)

             logger.info('finished stage %s, result size: %sx%s',
@@ -4,6 +4,9 @@ from diffusers import (
 from logging import getLogger
 from PIL import Image

+from ..device_pool import (
+    JobContext,
+)
 from ..diffusion.load import (
     load_pipeline,
 )
@@ -21,7 +24,8 @@ logger = getLogger(__name__)


 def blend_img2img(
-    _ctx: ServerContext,
+    job: JobContext,
+    _server: ServerContext,
     _stage: StageParams,
     params: ImageParams,
     source_image: Image.Image,
@@ -34,7 +38,7 @@ def blend_img2img(
     logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)

     pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
-                         params.model, params.provider, params.scheduler)
+                         params.model, params.scheduler, job.get_device())

     rng = np.random.RandomState(params.seed)

@@ -5,6 +5,9 @@ from logging import getLogger
 from PIL import Image
 from typing import Callable, Tuple

+from ..device_pool import (
+    JobContext,
+)
 from ..diffusion.load import (
     get_latents_from_seed,
     load_pipeline,
@@ -38,7 +41,8 @@ logger = getLogger(__name__)


 def blend_inpaint(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     stage: StageParams,
     params: ImageParams,
     source_image: Image.Image,
@@ -65,9 +69,9 @@ def blend_inpaint(
         mask_filter=mask_filter)

     if is_debug():
-        save_image(ctx, 'last-source.png', source_image)
-        save_image(ctx, 'last-mask.png', mask_image)
-        save_image(ctx, 'last-noise.png', noise_image)
+        save_image(server, 'last-source.png', source_image)
+        save_image(server, 'last-mask.png', mask_image)
+        save_image(server, 'last-noise.png', noise_image)

     def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
         left, top, tile = dims
@@ -75,11 +79,11 @@ def blend_inpaint(
         mask = mask_image.crop((left, top, left + tile, top + tile))

         if is_debug():
-            save_image(ctx, 'tile-source.png', image)
-            save_image(ctx, 'tile-mask.png', mask)
+            save_image(server, 'tile-source.png', image)
+            save_image(server, 'tile-mask.png', mask)

         pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
-                             params.model, params.provider, params.scheduler)
+                             params.model, params.scheduler, job.get_device())

         latents = get_latents_from_seed(params.seed, size)
         rng = np.random.RandomState(params.seed)
@@ -5,6 +5,9 @@ from PIL import Image
 from realesrgan import RealESRGANer
 from typing import Optional

+from ..device_pool import (
+    JobContext,
+)
 from ..params import (
     ImageParams,
     StageParams,
@@ -60,7 +63,8 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[


 def correct_gfpgan(
-    ctx: ServerContext,
+    _job: JobContext,
+    server: ServerContext,
     _stage: StageParams,
     _params: ImageParams,
     source_image: Image.Image,
@@ -74,7 +78,7 @@ def correct_gfpgan(
         return source_image

     logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model)
-    gfpgan = load_gfpgan(ctx, upscale, upsampler=upsampler)
+    gfpgan = load_gfpgan(server, upscale, upsampler=upsampler)

     output = np.array(source_image)
     _, _, output = gfpgan.enhance(
@@ -1,6 +1,9 @@
 from logging import getLogger
 from PIL import Image

+from ..device_pool import (
+    JobContext,
+)
 from ..params import (
     ImageParams,
     StageParams,
@@ -16,6 +19,7 @@ logger = getLogger(__name__)


 def persist_disk(
+    _job: JobContext,
     ctx: ServerContext,
     _stage: StageParams,
     _params: ImageParams,
@@ -5,6 +5,9 @@ from io import BytesIO
 from logging import getLogger
 from PIL import Image

+from ..device_pool import (
+    JobContext,
+)
 from ..params import (
     ImageParams,
     StageParams,
@@ -1,6 +1,9 @@
 from logging import getLogger
 from PIL import Image

+from ..device_pool import (
+    JobContext,
+)
 from ..params import (
     ImageParams,
     Size,
@@ -1,6 +1,9 @@
 from logging import getLogger
 from PIL import Image

+from ..device_pool import (
+    JobContext,
+)
 from ..params import (
     ImageParams,
     Size,
@@ -2,6 +2,9 @@ from logging import getLogger
 from PIL import Image
 from typing import Callable

+from ..device_pool import (
+    JobContext,
+)
 from ..params import (
     ImageParams,
     Size,
@@ -4,6 +4,9 @@ from diffusers import (
 from logging import getLogger
 from PIL import Image

+from ..device_pool import (
+    JobContext,
+)
 from ..diffusion.load import (
     get_latents_from_seed,
     load_pipeline,
@@ -23,7 +26,8 @@ logger = getLogger(__name__)


 def source_txt2img(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     stage: StageParams,
     params: ImageParams,
     source_image: Image.Image,
@@ -39,7 +43,7 @@ def source_txt2img(
         logger.warn('a source image was passed to a txt2img stage, but will be discarded')

     pipe = load_pipeline(OnnxStableDiffusionPipeline,
-                         params.model, params.provider, params.scheduler)
+                         params.model, params.scheduler, job.get_device())

     latents = get_latents_from_seed(params.seed, size)
     rng = np.random.RandomState(params.seed)
@@ -5,6 +5,9 @@ from logging import getLogger
 from PIL import Image, ImageDraw
 from typing import Callable, Tuple

+from ..device_pool import (
+    JobContext,
+)
 from ..diffusion.load import (
     get_latents_from_seed,
     get_tile_latents,
@@ -4,10 +4,14 @@ from os import path
 from PIL import Image
 from realesrgan import RealESRGANer

+from ..device_pool import (
+    JobContext,
+)
 from ..onnx import (
     OnnxNet,
 )
 from ..params import (
+    DeviceParams,
     ImageParams,
     StageParams,
     UpscaleParams,
@@ -25,7 +29,7 @@ last_pipeline_instance = None
 last_pipeline_params = (None, None)


-def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
+def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
     global last_pipeline_instance
     global last_pipeline_params

@@ -41,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):

     # use ONNX acceleration, if available
     if params.format == 'onnx':
-        model = OnnxNet(ctx, model_file, provider=params.provider)
+        model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options)
     elif params.format == 'pth':
         model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                         num_block=23, num_grow_ch=32, scale=params.scale)
@@ -76,7 +80,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):


 def upscale_resrgan(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     stage: StageParams,
     _params: ImageParams,
     source_image: Image.Image,
@@ -87,7 +92,7 @@ def upscale_resrgan(
     logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale)

     output = np.array(source_image)
-    upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size)
+    upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size)

     output, _ = upsampler.enhance(output, outscale=upscale.outscale)

@@ -5,10 +5,14 @@ from logging import getLogger
 from os import path
 from PIL import Image

+from ..device_pool import (
+    JobContext,
+)
 from ..diffusion.pipeline_onnx_stable_diffusion_upscale import (
     OnnxStableDiffusionUpscalePipeline,
 )
 from ..params import (
+    DeviceParams,
     ImageParams,
     StageParams,
     UpscaleParams,
@@ -27,7 +31,7 @@ last_pipeline_instance = None
 last_pipeline_params = (None, None)


-def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
+def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams):
     global last_pipeline_instance
     global last_pipeline_params

@@ -39,11 +43,11 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):
         return last_pipeline_instance

     if upscale.format == 'onnx':
-        logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, upscale.provider)
-        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider)
+        logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider)
+        pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)
     else:
-        logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, upscale.provider)
-        pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider)
+        logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider)
+        pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options)

     last_pipeline_instance = pipeline
     last_pipeline_params = cache_params
@@ -53,7 +57,8 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams):


 def upscale_stable_diffusion(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     _stage: StageParams,
     params: ImageParams,
     source: Image.Image,
@@ -65,7 +70,7 @@ def upscale_stable_diffusion(
     prompt = prompt or params.prompt
     logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt)

-    pipeline = load_stable_diffusion(ctx, upscale)
+    pipeline = load_stable_diffusion(server, upscale, job.get_device())
     generator = torch.manual_seed(params.seed)

     return pipeline(
@@ -133,6 +133,9 @@ class DevicePoolExecutor:
             return True
         else:
             job.set_cancel()
+            return True
+
+        return False

     def done(self, key: str) -> Tuple[bool, int]:
         for job in self.jobs:
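For context alongside the pool changes above, a rough sketch of the JobContext/DeviceParams contract that the new stage signatures assume; the field and method names are inferred from this diff, not copied from the repo:

class DeviceParams:
    def __init__(self, device: str, provider: str, options=None):
        self.device = device      # e.g. 'cuda'
        self.provider = provider  # e.g. 'CUDAExecutionProvider'
        self.options = options    # e.g. {'device_id': 0}


class JobContext:
    def __init__(self, device: DeviceParams):
        self.device = device

    def get_device(self) -> DeviceParams:
        # stages call this to learn which device the pool assigned to the job
        return self.device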
@@ -5,6 +5,7 @@ from logging import getLogger
 from typing import Any, Optional, Tuple

 from ..params import (
+    DeviceParams,
     Size,
 )
 from ..utils import (
@@ -45,12 +46,12 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np
     return full_latents[:, :, y:yt, x:xt]


-def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None):
+def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams):
     global last_pipeline_instance
     global last_pipeline_scheduler
     global last_pipeline_options

-    options = (pipeline, model, provider)
+    options = (pipeline, model, device.provider)
     if last_pipeline_instance != None and last_pipeline_options == options:
         logger.debug('reusing existing diffusion pipeline')
         pipe = last_pipeline_instance
@@ -61,11 +62,18 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
         run_gc()

         logger.debug('loading new diffusion pipeline from %s', model)
+        scheduler = scheduler.from_pretrained(
+            model,
+            provider=device.provider,
+            sess_options=device.options,
+            subfolder='scheduler',
+        )
         pipe = pipeline.from_pretrained(
             model,
-            provider=provider,
+            provider=device.provider,
             safety_checker=None,
-            scheduler=scheduler.from_pretrained(model, subfolder='scheduler')
+            scheduler=scheduler,
+            sess_options=device.options,
         )

         if device is not None and hasattr(pipe, 'to'):
@@ -78,7 +86,11 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu
     if last_pipeline_scheduler != scheduler:
         logger.debug('loading new diffusion scheduler')
         scheduler = scheduler.from_pretrained(
-            model, subfolder='scheduler')
+            model,
+            provider=device.provider,
+            sess_options=device.options,
+            subfolder='scheduler',
+        )

     if device is not None and hasattr(scheduler, 'to'):
         scheduler = scheduler.to(device)
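The cache key in load_pipeline now includes device.provider, so switching providers forces a reload instead of silently reusing a session bound to another device. diffusers' ONNX pipelines forward provider and sess_options to their underlying OnnxRuntimeModel sessions, which is what makes the per-device options take effect; a minimal standalone sketch with a hypothetical local model path:

from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    './models/stable-diffusion-onnx',  # hypothetical exported model
    provider='CUDAExecutionProvider',
    safety_checker=None,
)
image = pipe('an astronaut riding a horse', num_inference_steps=25).images[0]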
@@ -48,9 +48,8 @@ def run_txt2img_pipeline(
     output: str,
     upscale: UpscaleParams
 ) -> None:
-    device = job.get_device()
     pipe = load_pipeline(OnnxStableDiffusionPipeline,
-                         params.model, device.provider, params.scheduler, device=device.torch_device())
+                         params.model, params.scheduler, job.get_device())

     latents = get_latents_from_seed(params.seed, size)
     rng = np.random.RandomState(params.seed)
@@ -90,9 +89,8 @@ def run_img2img_pipeline(
     source_image: Image.Image,
     strength: float,
 ) -> None:
-    device = job.get_device()
     pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
-                         params.model, device.provider, params.scheduler, device=device.torch_device())
+                         params.model, params.scheduler, job.get_device())

     rng = np.random.RandomState(params.seed)

@@ -1,6 +1,6 @@
 from onnxruntime import InferenceSession
 from os import path
-from typing import Any
+from typing import Any, Optional

 import numpy as np
 import torch
@@ -43,13 +43,19 @@ class OnnxNet():
     Provides the RRDBNet interface using an ONNX session for DirectML acceleration.
     '''

-    def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None:
+    def __init__(
+        self,
+        server: ServerContext,
+        model: str,
+        provider: str = 'DmlExecutionProvider',
+        sess_options: Optional[dict] = None,
+    ) -> None:
         '''
         TODO: get platform provider from request params
         '''
-        model_path = path.join(ctx.model_path, model)
+        model_path = path.join(server.model_path, model)
         self.session = InferenceSession(
-            model_path, providers=[provider])
+            model_path, providers=[provider], sess_options=sess_options)

     def __call__(self, image: Any) -> Any:
         input_name = self.session.get_inputs()[0].name
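One caveat worth noting: sess_options here is annotated Optional[dict], but onnxruntime's InferenceSession expects a SessionOptions instance for that parameter, and device selection itself belongs in the providers list rather than in session options. A small sketch of building a real SessionOptions object, with illustrative tuning values:

from onnxruntime import GraphOptimizationLevel, SessionOptions

sess_options = SessionOptions()
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 2  # illustrative; the default is usually fine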
@@ -79,7 +79,6 @@ class ImageParams:
     def __init__(
         self,
         model: str,
-        provider: str,
         scheduler: Any,
         prompt: str,
         negative_prompt: Optional[str],
@@ -88,7 +87,6 @@ class ImageParams:
         seed: int
     ) -> None:
         self.model = model
-        self.provider = provider
         self.scheduler = scheduler
         self.prompt = prompt
         self.negative_prompt = negative_prompt
@@ -96,10 +94,9 @@ class ImageParams:
         self.steps = steps
         self.seed = seed

-    def tojson(self) -> Dict[str, Param]:
+    def tojson(self) -> Dict[str, Optional[Param]]:
         return {
             'model': self.model,
-            'provider': self.provider,
             'scheduler': self.scheduler.__name__,
             'seed': self.seed,
             'prompt': self.prompt,
@@ -130,7 +127,6 @@ class UpscaleParams():
     def __init__(
         self,
         upscale_model: str,
-        provider: str,
         correction_model: Optional[str] = None,
         denoise: float = 0.5,
         faces=True,
@@ -143,7 +139,6 @@ class UpscaleParams():
         tile_pad: int = 10,
     ) -> None:
         self.upscale_model = upscale_model
-        self.provider = provider
         self.correction_model = correction_model
         self.denoise = denoise
         self.faces = faces
@@ -158,7 +153,6 @@ class UpscaleParams():
     def rescale(self, scale: int):
         return UpscaleParams(
             self.upscale_model,
-            self.provider,
             correction_model=self.correction_model,
             denoise=self.denoise,
             faces=self.faces,
@@ -226,7 +226,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]:
     logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s",
                 user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt)

-    params = ImageParams(model_path, device.provider, scheduler, prompt,
+    params = ImageParams(model_path, scheduler, prompt,
                          negative_prompt, cfg, steps, seed)
     size = Size(width, height)
     return (device, params, size)
@@ -245,7 +245,7 @@ def border_from_request() -> Border:
     return Border(left, right, top, bottom)


-def upscale_from_request(provider: str) -> UpscaleParams:
+def upscale_from_request() -> UpscaleParams:
     denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
     scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
     outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
@@ -257,7 +257,6 @@ def upscale_from_request(provider: str) -> UpscaleParams:

     return UpscaleParams(
         upscaling,
-        provider,
         correction_model=correction,
         denoise=denoise,
         faces=faces,
@@ -327,7 +326,7 @@ def load_platforms():
         if platform_providers[potential] in providers and potential not in context.block_platforms:
             if potential == 'cuda':
                 for i in range(torch.cuda.device_count()):
-                    available_platforms.append(DeviceParams('%s:%s' % (potential, i), platform_providers[potential], {
+                    available_platforms.append(DeviceParams(potential, platform_providers[potential], {
                         'device_id': i,
                     }))
             else:
@@ -344,7 +343,6 @@ def load_platforms():

         return -1

-
     available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))

     logger.info('available acceleration platforms: %s',
@@ -456,7 +454,7 @@ def img2img():
     source_image = Image.open(BytesIO(source_file.read())).convert('RGB')

     device, params, size = pipeline_from_request()
-    upscale = upscale_from_request(params.provider)
+    upscale = upscale_from_request()

     strength = get_and_clamp_float(
         request.args,
@@ -54,6 +54,7 @@
     "scandir",
     "scipy",
     "scrollback",
+    "sess",
     "Singlestep",
     "spacy",
     "spinalcase",