diff --git a/api/dev-requirements.txt b/api/dev-requirements.txt new file mode 100644 index 00000000..c13ad788 --- /dev/null +++ b/api/dev-requirements.txt @@ -0,0 +1,6 @@ +mypy + +types-Flask-Cors +types-jsonschema +types-Pillow +types-PyYAML \ No newline at end of file diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index b01e46b0..93f48c12 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -28,6 +28,7 @@ logger = getLogger(__name__) class StageCallback(Protocol): def __call__( self, + job: JobContext, ctx: ServerContext, stage: StageParams, params: ImageParams, @@ -83,7 +84,7 @@ class ChainPipeline: stage_params.tile_size) def stage_tile(tile: Image.Image, _dims) -> Image.Image: - tile = stage_pipe(server, stage_params, params, tile, + tile = stage_pipe(job, server, stage_params, params, tile, **kwargs) if is_debug(): @@ -95,7 +96,7 @@ class ChainPipeline: image, stage_params.tile_size, stage_params.outscale, [stage_tile]) else: logger.info('image within tile size, running stage') - image = stage_pipe(server, stage_params, params, image, + image = stage_pipe(job, server, stage_params, params, image, **kwargs) logger.info('finished stage %s, result size: %sx%s', diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 9078e63c..ae6b4dcc 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -4,6 +4,9 @@ from diffusers import ( from logging import getLogger from PIL import Image +from ..device_pool import ( + JobContext, +) from ..diffusion.load import ( load_pipeline, ) @@ -21,7 +24,8 @@ logger = getLogger(__name__) def blend_img2img( - _ctx: ServerContext, + job: JobContext, + _server: ServerContext, _stage: StageParams, params: ImageParams, source_image: Image.Image, @@ -34,7 +38,7 @@ def blend_img2img( logger.info('generating image using img2img, %s steps: %s', params.steps, prompt) pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, - params.model, params.provider, params.scheduler) + params.model, params.scheduler, job.get_device()) rng = np.random.RandomState(params.seed) diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 847f1ffb..3eaccc0b 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -5,6 +5,9 @@ from logging import getLogger from PIL import Image from typing import Callable, Tuple +from ..device_pool import ( + JobContext, +) from ..diffusion.load import ( get_latents_from_seed, load_pipeline, @@ -38,7 +41,8 @@ logger = getLogger(__name__) def blend_inpaint( - ctx: ServerContext, + job: JobContext, + server: ServerContext, stage: StageParams, params: ImageParams, source_image: Image.Image, @@ -65,9 +69,9 @@ def blend_inpaint( mask_filter=mask_filter) if is_debug(): - save_image(ctx, 'last-source.png', source_image) - save_image(ctx, 'last-mask.png', mask_image) - save_image(ctx, 'last-noise.png', noise_image) + save_image(server, 'last-source.png', source_image) + save_image(server, 'last-mask.png', mask_image) + save_image(server, 'last-noise.png', noise_image) def outpaint(image: Image.Image, dims: Tuple[int, int, int]): left, top, tile = dims @@ -75,11 +79,11 @@ def blend_inpaint( mask = mask_image.crop((left, top, left + tile, top + tile)) if is_debug(): - save_image(ctx, 'tile-source.png', image) - save_image(ctx, 'tile-mask.png', mask) + save_image(server, 'tile-source.png', image) + save_image(server, 'tile-mask.png', mask) pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, - params.model, params.provider, params.scheduler) + params.model, params.scheduler, job.get_device()) latents = get_latents_from_seed(params.seed, size) rng = np.random.RandomState(params.seed) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 147d32c2..6dcbfbea 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -5,6 +5,9 @@ from PIL import Image from realesrgan import RealESRGANer from typing import Optional +from ..device_pool import ( + JobContext, +) from ..params import ( ImageParams, StageParams, @@ -60,7 +63,8 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[ def correct_gfpgan( - ctx: ServerContext, + _job: JobContext, + server: ServerContext, _stage: StageParams, _params: ImageParams, source_image: Image.Image, @@ -74,7 +78,7 @@ def correct_gfpgan( return source_image logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model) - gfpgan = load_gfpgan(ctx, upscale, upsampler=upsampler) + gfpgan = load_gfpgan(server, upscale, upsampler=upsampler) output = np.array(source_image) _, _, output = gfpgan.enhance( diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index fb426850..b9495352 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -1,6 +1,9 @@ from logging import getLogger from PIL import Image +from ..device_pool import ( + JobContext, +) from ..params import ( ImageParams, StageParams, @@ -16,6 +19,7 @@ logger = getLogger(__name__) def persist_disk( + _job: JobContext, ctx: ServerContext, _stage: StageParams, _params: ImageParams, diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index d54f4a93..1fd966fb 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -5,6 +5,9 @@ from io import BytesIO from logging import getLogger from PIL import Image +from ..device_pool import ( + JobContext, +) from ..params import ( ImageParams, StageParams, diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 3e3a5f96..ca21ae02 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -1,6 +1,9 @@ from logging import getLogger from PIL import Image +from ..device_pool import ( + JobContext, +) from ..params import ( ImageParams, Size, diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 8c3f0e19..02498daa 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -1,6 +1,9 @@ from logging import getLogger from PIL import Image +from ..device_pool import ( + JobContext, +) from ..params import ( ImageParams, Size, diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 33c6c226..8bd360dc 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -2,6 +2,9 @@ from logging import getLogger from PIL import Image from typing import Callable +from ..device_pool import ( + JobContext, +) from ..params import ( ImageParams, Size, diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 4ae06580..e72fba37 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -4,6 +4,9 @@ from diffusers import ( from logging import getLogger from PIL import Image +from ..device_pool import ( + JobContext, +) from ..diffusion.load import ( get_latents_from_seed, load_pipeline, @@ -23,7 +26,8 @@ logger = getLogger(__name__) def source_txt2img( - ctx: ServerContext, + job: JobContext, + server: ServerContext, stage: StageParams, params: ImageParams, source_image: Image.Image, @@ -39,7 +43,7 @@ def source_txt2img( logger.warn('a source image was passed to a txt2img stage, but will be discarded') pipe = load_pipeline(OnnxStableDiffusionPipeline, - params.model, params.provider, params.scheduler) + params.model, params.scheduler, job.get_device()) latents = get_latents_from_seed(params.seed, size) rng = np.random.RandomState(params.seed) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 254f2c03..8e1f29c2 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -5,6 +5,9 @@ from logging import getLogger from PIL import Image, ImageDraw from typing import Callable, Tuple +from ..device_pool import ( + JobContext, +) from ..diffusion.load import ( get_latents_from_seed, get_tile_latents, diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index be862d95..3a030dbd 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -4,10 +4,14 @@ from os import path from PIL import Image from realesrgan import RealESRGANer +from ..device_pool import ( + JobContext, +) from ..onnx import ( OnnxNet, ) from ..params import ( + DeviceParams, ImageParams, StageParams, UpscaleParams, @@ -25,7 +29,7 @@ last_pipeline_instance = None last_pipeline_params = (None, None) -def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): +def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0): global last_pipeline_instance global last_pipeline_params @@ -41,7 +45,7 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): # use ONNX acceleration, if available if params.format == 'onnx': - model = OnnxNet(ctx, model_file, provider=params.provider) + model = OnnxNet(ctx, model_file, provider=device.provider, sess_options=device.options) elif params.format == 'pth': model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=params.scale) @@ -76,7 +80,8 @@ def load_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): def upscale_resrgan( - ctx: ServerContext, + job: JobContext, + server: ServerContext, stage: StageParams, _params: ImageParams, source_image: Image.Image, @@ -87,7 +92,7 @@ def upscale_resrgan( logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale) output = np.array(source_image) - upsampler = load_resrgan(ctx, upscale, tile=stage.tile_size) + upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size) output, _ = upsampler.enhance(output, outscale=upscale.outscale) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 15ce479d..86cef5fe 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -5,10 +5,14 @@ from logging import getLogger from os import path from PIL import Image +from ..device_pool import ( + JobContext, +) from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) from ..params import ( + DeviceParams, ImageParams, StageParams, UpscaleParams, @@ -27,7 +31,7 @@ last_pipeline_instance = None last_pipeline_params = (None, None) -def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams): +def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams): global last_pipeline_instance global last_pipeline_params @@ -39,11 +43,11 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams): return last_pipeline_instance if upscale.format == 'onnx': - logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, upscale.provider) - pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider) + logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider) + pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options) else: - logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, upscale.provider) - pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=upscale.provider) + logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider) + pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, sess_options=device.options) last_pipeline_instance = pipeline last_pipeline_params = cache_params @@ -53,7 +57,8 @@ def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams): def upscale_stable_diffusion( - ctx: ServerContext, + job: JobContext, + server: ServerContext, _stage: StageParams, params: ImageParams, source: Image.Image, @@ -65,7 +70,7 @@ def upscale_stable_diffusion( prompt = prompt or params.prompt logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt) - pipeline = load_stable_diffusion(ctx, upscale) + pipeline = load_stable_diffusion(server, upscale, job.get_device()) generator = torch.manual_seed(params.seed) return pipeline( diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py index e55665a3..7c6b051f 100644 --- a/api/onnx_web/device_pool.py +++ b/api/onnx_web/device_pool.py @@ -133,6 +133,9 @@ class DevicePoolExecutor: return True else: job.set_cancel() + return True + + return False def done(self, key: str) -> Tuple[bool, int]: for job in self.jobs: diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index a070f4e3..e7661a70 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -5,6 +5,7 @@ from logging import getLogger from typing import Any, Optional, Tuple from ..params import ( + DeviceParams, Size, ) from ..utils import ( @@ -45,12 +46,12 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np return full_latents[:, :, y:yt, x:xt] -def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, scheduler: Any, device: Optional[str] = None): +def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams): global last_pipeline_instance global last_pipeline_scheduler global last_pipeline_options - options = (pipeline, model, provider) + options = (pipeline, model, device.provider) if last_pipeline_instance != None and last_pipeline_options == options: logger.debug('reusing existing diffusion pipeline') pipe = last_pipeline_instance @@ -61,11 +62,18 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu run_gc() logger.debug('loading new diffusion pipeline from %s', model) + scheduler = scheduler.from_pretrained( + model, + provider=device.provider, + sess_options=device.options, + subfolder='scheduler', + ) pipe = pipeline.from_pretrained( model, - provider=provider, + provider=device.provider, safety_checker=None, - scheduler=scheduler.from_pretrained(model, subfolder='scheduler') + scheduler=scheduler, + sess_options=device.options, ) if device is not None and hasattr(pipe, 'to'): @@ -78,7 +86,11 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu if last_pipeline_scheduler != scheduler: logger.debug('loading new diffusion scheduler') scheduler = scheduler.from_pretrained( - model, subfolder='scheduler') + model, + provider=device.provider, + sess_options=device.options, + subfolder='scheduler', + ) if device is not None and hasattr(scheduler, 'to'): scheduler = scheduler.to(device) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 1b44c700..ac7ad6fc 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -48,9 +48,8 @@ def run_txt2img_pipeline( output: str, upscale: UpscaleParams ) -> None: - device = job.get_device() pipe = load_pipeline(OnnxStableDiffusionPipeline, - params.model, device.provider, params.scheduler, device=device.torch_device()) + params.model, params.scheduler, job.get_device()) latents = get_latents_from_seed(params.seed, size) rng = np.random.RandomState(params.seed) @@ -90,9 +89,8 @@ def run_img2img_pipeline( source_image: Image.Image, strength: float, ) -> None: - device = job.get_device() pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, - params.model, device.provider, params.scheduler, device=device.torch_device()) + params.model, params.scheduler, job.get_device()) rng = np.random.RandomState(params.seed) diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 90c22114..9b0d14f9 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -1,6 +1,6 @@ from onnxruntime import InferenceSession from os import path -from typing import Any +from typing import Any, Optional import numpy as np import torch @@ -43,13 +43,19 @@ class OnnxNet(): Provides the RRDBNet interface using an ONNX session for DirectML acceleration. ''' - def __init__(self, ctx: ServerContext, model: str, provider='DmlExecutionProvider') -> None: + def __init__( + self, + server: ServerContext, + model: str, + provider: str = 'DmlExecutionProvider', + sess_options: Optional[dict] = None, + ) -> None: ''' TODO: get platform provider from request params ''' - model_path = path.join(ctx.model_path, model) + model_path = path.join(server.model_path, model) self.session = InferenceSession( - model_path, providers=[provider]) + model_path, providers=[provider], sess_options=sess_options) def __call__(self, image: Any) -> Any: input_name = self.session.get_inputs()[0].name diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 296c8cbc..a00f2a1f 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -79,7 +79,6 @@ class ImageParams: def __init__( self, model: str, - provider: str, scheduler: Any, prompt: str, negative_prompt: Optional[str], @@ -88,7 +87,6 @@ class ImageParams: seed: int ) -> None: self.model = model - self.provider = provider self.scheduler = scheduler self.prompt = prompt self.negative_prompt = negative_prompt @@ -96,10 +94,9 @@ class ImageParams: self.steps = steps self.seed = seed - def tojson(self) -> Dict[str, Param]: + def tojson(self) -> Dict[str, Optional[Param]]: return { 'model': self.model, - 'provider': self.provider, 'scheduler': self.scheduler.__name__, 'seed': self.seed, 'prompt': self.prompt, @@ -130,7 +127,6 @@ class UpscaleParams(): def __init__( self, upscale_model: str, - provider: str, correction_model: Optional[str] = None, denoise: float = 0.5, faces=True, @@ -143,7 +139,6 @@ class UpscaleParams(): tile_pad: int = 10, ) -> None: self.upscale_model = upscale_model - self.provider = provider self.correction_model = correction_model self.denoise = denoise self.faces = faces @@ -158,7 +153,6 @@ class UpscaleParams(): def rescale(self, scale: int): return UpscaleParams( self.upscale_model, - self.provider, correction_model=self.correction_model, denoise=self.denoise, faces=self.faces, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 088ed816..779c7741 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -226,7 +226,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt) - params = ImageParams(model_path, device.provider, scheduler, prompt, + params = ImageParams(model_path, scheduler, prompt, negative_prompt, cfg, steps, seed) size = Size(width, height) return (device, params, size) @@ -245,7 +245,7 @@ def border_from_request() -> Border: return Border(left, right, top, bottom) -def upscale_from_request(provider: str) -> UpscaleParams: +def upscale_from_request() -> UpscaleParams: denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0) scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1) outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1) @@ -257,7 +257,6 @@ def upscale_from_request(provider: str) -> UpscaleParams: return UpscaleParams( upscaling, - provider, correction_model=correction, denoise=denoise, faces=faces, @@ -327,7 +326,7 @@ def load_platforms(): if platform_providers[potential] in providers and potential not in context.block_platforms: if potential == 'cuda': for i in range(torch.cuda.device_count()): - available_platforms.append(DeviceParams('%s:%s' % (potential, i), platform_providers[potential], { + available_platforms.append(DeviceParams(potential, platform_providers[potential], { 'device_id': i, })) else: @@ -344,7 +343,6 @@ def load_platforms(): return -1 - available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last)) logger.info('available acceleration platforms: %s', @@ -456,7 +454,7 @@ def img2img(): source_image = Image.open(BytesIO(source_file.read())).convert('RGB') device, params, size = pipeline_from_request() - upscale = upscale_from_request(params.provider) + upscale = upscale_from_request() strength = get_and_clamp_float( request.args, diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace index 6447e815..42dd5d87 100644 --- a/onnx-web.code-workspace +++ b/onnx-web.code-workspace @@ -54,6 +54,7 @@ "scandir", "scipy", "scrollback", + "sess", "Singlestep", "spacy", "spinalcase",