From 54dd34d211464f4b26b3beb93ed0c777c3c902d4 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 5 Feb 2023 07:53:26 -0600 Subject: [PATCH] lint(api): apply black and isort style --- api/Makefile | 13 + api/dev-requirements.txt | 3 + api/onnx_web/__init__.py | 68 +- api/onnx_web/chain/__init__.py | 55 +- api/onnx_web/chain/base.py | 105 +-- api/onnx_web/chain/blend_img2img.py | 38 +- api/onnx_web/chain/blend_inpaint.py | 73 +-- api/onnx_web/chain/correct_codeformer.py | 76 ++- api/onnx_web/chain/correct_gfpgan.py | 58 +- api/onnx_web/chain/persist_disk.py | 24 +- api/onnx_web/chain/persist_s3.py | 29 +- api/onnx_web/chain/reduce_crop.py | 24 +- api/onnx_web/chain/reduce_thumbnail.py | 21 +- api/onnx_web/chain/source_noise.py | 29 +- api/onnx_web/chain/source_txt2img.py | 40 +- api/onnx_web/chain/upscale_outpaint.py | 77 +-- api/onnx_web/chain/upscale_resrgan.py | 76 +-- .../chain/upscale_stable_diffusion.py | 56 +- api/onnx_web/chain/utils.py | 13 +- api/onnx_web/convert.py | 317 +++++---- api/onnx_web/device_pool.py | 75 ++- api/onnx_web/diffusion/load.py | 55 +- .../pipeline_onnx_stable_diffusion_upscale.py | 117 ++-- api/onnx_web/diffusion/run.py | 84 +-- api/onnx_web/image.py | 126 ++-- api/onnx_web/logging.py | 13 +- api/onnx_web/onnx/__init__.py | 5 +- api/onnx_web/onnx/onnx_net.py | 24 +- api/onnx_web/output.py | 52 +- api/onnx_web/params.py | 67 +- api/onnx_web/serve.py | 606 ++++++++++-------- api/onnx_web/upscale.py | 38 +- api/onnx_web/utils.py | 63 +- api/pyproject.toml | 2 + 34 files changed, 1271 insertions(+), 1251 deletions(-) create mode 100644 api/pyproject.toml diff --git a/api/Makefile b/api/Makefile index 705023a6..5f248cbd 100644 --- a/api/Makefile +++ b/api/Makefile @@ -23,3 +23,16 @@ package-dist: package-upload: twine upload dist/* + +lint-check: + black --check --preview onnx_web + isort --check-only --skip __init__.py --filter-files onnx_web + flake8 --per-file-ignores="__init__.py:F401" onnx_web + +lint-fix: + black onnx_web + isort --skip __init__.py --filter-files onnx_web + flake8 --per-file-ignores="__init__.py:F401" onnx_web + +typecheck: + mypy -m onnx_web.serve diff --git a/api/dev-requirements.txt b/api/dev-requirements.txt index c13ad788..04117cda 100644 --- a/api/dev-requirements.txt +++ b/api/dev-requirements.txt @@ -1,3 +1,6 @@ +black +flake8 +isort mypy types-Flask-Cors diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index eb1395c6..f3e3aabe 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -1,49 +1,31 @@ from . 
import logging - -from .chain import ( - correct_gfpgan, - upscale_resrgan, - upscale_stable_diffusion, -) -from .diffusion.load import ( - get_latents_from_seed, - load_pipeline, -) +from .chain import correct_gfpgan, upscale_resrgan, upscale_stable_diffusion +from .diffusion.load import get_latents_from_seed, load_pipeline from .diffusion.run import ( - run_img2img_pipeline, - run_inpaint_pipeline, - run_txt2img_pipeline, + run_img2img_pipeline, + run_inpaint_pipeline, + run_txt2img_pipeline, ) from .image import ( - expand_image, - mask_filter_gaussian_multiply, - mask_filter_gaussian_screen, - mask_filter_none, - noise_source_fill_edge, - noise_source_fill_mask, - noise_source_gaussian, - noise_source_histogram, - noise_source_normal, - noise_source_uniform, -) -from .params import ( - Param, - Point, - Border, - Size, - ImageParams, - StageParams, - UpscaleParams, -) -from .upscale import ( - run_upscale_correction, + expand_image, + mask_filter_gaussian_multiply, + mask_filter_gaussian_screen, + mask_filter_none, + noise_source_fill_edge, + noise_source_fill_mask, + noise_source_gaussian, + noise_source_histogram, + noise_source_normal, + noise_source_uniform, ) +from .params import Border, ImageParams, Param, Point, Size, StageParams, UpscaleParams +from .upscale import run_upscale_correction from .utils import ( - get_and_clamp_float, - get_and_clamp_int, - get_from_list, - get_from_map, - get_not_empty, - base_join, - ServerContext, -) \ No newline at end of file + ServerContext, + base_join, + get_and_clamp_float, + get_and_clamp_int, + get_from_list, + get_from_map, + get_not_empty, +) diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index bbd3ca9e..a6349eb6 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,42 +1,13 @@ -from .base import ( - ChainPipeline, - PipelineStage, - StageCallback, - StageParams, -) -from .blend_img2img import ( - blend_img2img, -) -from .blend_inpaint import ( - blend_inpaint, -) -from .correct_gfpgan import ( - correct_gfpgan, -) -from .persist_disk import ( - persist_disk, -) -from .persist_s3 import ( - persist_s3, -) -from .reduce_crop import ( - reduce_crop, -) -from .reduce_thumbnail import ( - reduce_thumbnail, -) -from .source_noise import ( - source_noise, -) -from .source_txt2img import ( - source_txt2img, -) -from .upscale_outpaint import ( - upscale_outpaint, -) -from .upscale_resrgan import ( - upscale_resrgan, -) -from .upscale_stable_diffusion import ( - upscale_stable_diffusion, -) \ No newline at end of file +from .base import ChainPipeline, PipelineStage, StageCallback, StageParams +from .blend_img2img import blend_img2img +from .blend_inpaint import blend_inpaint +from .correct_gfpgan import correct_gfpgan +from .persist_disk import persist_disk +from .persist_s3 import persist_s3 +from .reduce_crop import reduce_crop +from .reduce_thumbnail import reduce_thumbnail +from .source_noise import source_noise +from .source_txt2img import source_txt2img +from .upscale_outpaint import upscale_outpaint +from .upscale_resrgan import upscale_resrgan +from .upscale_stable_diffusion import upscale_stable_diffusion diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 93f48c12..8e7a6487 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -1,26 +1,15 @@ from datetime import timedelta from logging import getLogger -from PIL import Image from time import monotonic from typing import Any, List, Optional, Protocol, Tuple -from ..device_pool 
import ( - JobContext, -) -from ..params import ( - ImageParams, - StageParams, -) -from ..output import ( - save_image, -) -from ..utils import ( - is_debug, - ServerContext, -) -from .utils import ( - process_tile_grid, -) +from PIL import Image + +from ..device_pool import JobContext +from ..output import save_image +from ..params import ImageParams, StageParams +from ..utils import ServerContext, is_debug +from .utils import process_tile_grid logger = getLogger(__name__) @@ -42,33 +31,43 @@ PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]] class ChainPipeline: - ''' + """ Run many stages in series, passing the image results from each to the next, and processing tiles as needed. - ''' + """ def __init__( self, stages: List[PipelineStage] = [], ): - ''' + """ Create a new pipeline that will run the given stages. - ''' + """ self.stages = list(stages) def append(self, stage: PipelineStage): - ''' + """ Append an additional stage to this pipeline. - ''' + """ self.stages.append(stage) - def __call__(self, job: JobContext, server: ServerContext, params: ImageParams, source: Image.Image, **pipeline_kwargs) -> Image.Image: - ''' + def __call__( + self, + job: JobContext, + server: ServerContext, + params: ImageParams, + source: Image.Image, + **pipeline_kwargs + ) -> Image.Image: + """ TODO: handle List[Image] outputs - ''' + """ start = monotonic() - logger.info('running pipeline on source image with dimensions %sx%s', - source.width, source.height) + logger.info( + "running pipeline on source image with dimensions %sx%s", + source.width, + source.height, + ) image = source for stage_pipe, stage_params, stage_kwargs in self.stages: @@ -76,37 +75,51 @@ class ChainPipeline: kwargs = stage_kwargs or {} kwargs = {**pipeline_kwargs, **kwargs} - logger.info('running stage %s on image with dimensions %sx%s, %s', - name, image.width, image.height, kwargs.keys()) + logger.info( + "running stage %s on image with dimensions %sx%s, %s", + name, + image.width, + image.height, + kwargs.keys(), + ) - if image.width > stage_params.tile_size or image.height > stage_params.tile_size: - logger.info('image larger than tile size of %s, tiling stage', - stage_params.tile_size) + if ( + image.width > stage_params.tile_size + or image.height > stage_params.tile_size + ): + logger.info( + "image larger than tile size of %s, tiling stage", + stage_params.tile_size, + ) def stage_tile(tile: Image.Image, _dims) -> Image.Image: - tile = stage_pipe(job, server, stage_params, params, tile, - **kwargs) + tile = stage_pipe(job, server, stage_params, params, tile, **kwargs) if is_debug(): - save_image(server, 'last-tile.png', tile) + save_image(server, "last-tile.png", tile) return tile image = process_tile_grid( - image, stage_params.tile_size, stage_params.outscale, [stage_tile]) + image, stage_params.tile_size, stage_params.outscale, [stage_tile] + ) else: - logger.info('image within tile size, running stage') - image = stage_pipe(job, server, stage_params, params, image, - **kwargs) + logger.info("image within tile size, running stage") + image = stage_pipe(job, server, stage_params, params, image, **kwargs) - logger.info('finished stage %s, result size: %sx%s', - name, image.width, image.height) + logger.info( + "finished stage %s, result size: %sx%s", name, image.width, image.height + ) if is_debug(): - save_image(server, 'last-stage.png', image) + save_image(server, "last-stage.png", image) end = monotonic() duration = timedelta(seconds=(end - start)) - logger.info('finished pipeline in %s, result 
size: %sx%s',
-                duration, image.width, image.height)
+    logger.info(
+        "finished pipeline in %s, result size: %sx%s",
+        duration,
+        image.width,
+        image.height,
+    )
 
         return image
diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py
index ae6b4dcc..5d0da615 100644
--- a/api/onnx_web/chain/blend_img2img.py
+++ b/api/onnx_web/chain/blend_img2img.py
@@ -1,24 +1,13 @@
-from diffusers import (
-    OnnxStableDiffusionImg2ImgPipeline,
-)
 from logging import getLogger
-from PIL import Image
-
-from ..device_pool import (
-    JobContext,
-)
-from ..diffusion.load import (
-    load_pipeline,
-)
-from ..params import (
-    ImageParams,
-    StageParams,
-)
-from ..utils import (
-    ServerContext,
-)
 
 import numpy as np
+from diffusers import OnnxStableDiffusionImg2ImgPipeline
+from PIL import Image
+
+from ..device_pool import JobContext
+from ..diffusion.load import load_pipeline
+from ..params import ImageParams, StageParams
+from ..utils import ServerContext
 
 logger = getLogger(__name__)
 
@@ -35,10 +24,14 @@ def blend_img2img(
     **kwargs,
 ) -> Image.Image:
     prompt = prompt or params.prompt
-    logger.info('generating image using img2img, %s steps: %s', params.steps, prompt)
+    logger.info("generating image using img2img, %s steps: %s", params.steps, prompt)
 
-    pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
-                         params.model, params.scheduler, job.get_device())
+    pipe = load_pipeline(
+        OnnxStableDiffusionImg2ImgPipeline,
+        params.model,
+        params.scheduler,
+        job.get_device(),
+    )
 
     rng = np.random.RandomState(params.seed)
 
@@ -53,6 +46,5 @@ def blend_img2img(
     )
     output = result.images[0]
 
-    logger.info('final output image size: %sx%s', output.width, output.height)
+    logger.info("final output image size: %sx%s", output.width, output.height)
     return output
-
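The img2img stage above and the inpaint stage below share the calling convention enforced by ChainPipeline.__call__ in chain/base.py: every stage is invoked as stage_pipe(job, server, stage_params, params, source, **kwargs), with the per-stage kwargs dict merged over any pipeline-level kwargs. A minimal wiring sketch, assuming StageParams constructs with usable defaults and that job and server are the JobContext and ServerContext instances built elsewhere in serve.py:

    from onnx_web.chain import ChainPipeline, blend_img2img
    from onnx_web.params import StageParams

    # a single-stage pipeline: (callback, stage params, extra stage kwargs)
    chain = ChainPipeline(
        stages=[(blend_img2img, StageParams(), {"prompt": "a watercolor landscape"})]
    )
    # the pipeline tiles the source automatically when it exceeds
    # stage_params.tile_size, then returns the final PIL image:
    # output = chain(job, server, image_params, source_image)

Because every stage returns a plain PIL image, further stages such as correct_gfpgan or upscale_resrgan can be appended to the same pipeline without touching the earlier ones.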
diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py
index 3eaccc0b..d59459c3 100644
--- a/api/onnx_web/chain/blend_inpaint.py
+++ b/api/onnx_web/chain/blend_inpaint.py
@@ -1,41 +1,17 @@
-from diffusers import (
-    OnnxStableDiffusionInpaintPipeline,
-)
 from logging import getLogger
-from PIL import Image
 from typing import Callable, Tuple
 
-from ..device_pool import (
-    JobContext,
-)
-from ..diffusion.load import (
-    get_latents_from_seed,
-    load_pipeline,
-)
-from ..image import (
-    expand_image,
-    mask_filter_none,
-    noise_source_histogram,
-)
-from ..params import (
-    Border,
-    ImageParams,
-    Size,
-    SizeChart,
-    StageParams,
-)
-from ..output import (
-    save_image,
-)
-from ..utils import (
-    is_debug,
-    ServerContext,
-)
-from .utils import (
-    process_tile_grid,
-)
-
 import numpy as np
+from diffusers import OnnxStableDiffusionInpaintPipeline
+from PIL import Image
+
+from ..device_pool import JobContext
+from ..diffusion.load import get_latents_from_seed, load_pipeline
+from ..image import expand_image, mask_filter_none, noise_source_histogram
+from ..output import save_image
+from ..params import Border, ImageParams, Size, SizeChart, StageParams
+from ..utils import ServerContext, is_debug
+from .utils import process_tile_grid
 
 logger = getLogger(__name__)
 
@@ -49,16 +25,16 @@ def blend_inpaint(
     *,
     expand: Border,
     mask_image: Image.Image = None,
-    fill_color: str = 'white',
+    fill_color: str = "white",
     mask_filter: Callable = mask_filter_none,
     noise_source: Callable = noise_source_histogram,
     **kwargs,
 ) -> Image.Image:
-    logger.info('upscaling image by expanding borders', expand)
+    logger.info("upscaling image by expanding borders: %s", expand)
 
     if mask_image is None:
         # if no mask was provided, keep the full source image
-        mask_image = Image.new('RGB', source_image.size, 'black')
+        mask_image = Image.new("RGB", source_image.size, "black")
 
     source_image, mask_image, noise_image, _full_dims = expand_image(
         source_image,
@@ -66,12 +42,13 @@ def blend_inpaint(
         mask_image,
         expand,
         fill=fill_color,
         noise_source=noise_source,
-        mask_filter=mask_filter)
+        mask_filter=mask_filter,
+    )
 
     if is_debug():
-        save_image(server, 'last-source.png', source_image)
-        save_image(server, 'last-mask.png', mask_image)
-        save_image(server, 'last-noise.png', noise_image)
+        save_image(server, "last-source.png", source_image)
+        save_image(server, "last-mask.png", mask_image)
+        save_image(server, "last-noise.png", noise_image)
 
     def outpaint(image: Image.Image, dims: Tuple[int, int, int]):
         left, top, tile = dims
@@ -79,11 +56,15 @@ def blend_inpaint(
         mask = mask_image.crop((left, top, left + tile, top + tile))
 
         if is_debug():
-            save_image(server, 'tile-source.png', image)
-            save_image(server, 'tile-mask.png', mask)
+            save_image(server, "tile-source.png", image)
+            save_image(server, "tile-mask.png", mask)
 
-        pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline,
-                             params.model, params.scheduler, job.get_device())
+        pipe = load_pipeline(
+            OnnxStableDiffusionInpaintPipeline,
+            params.model,
+            params.scheduler,
+            job.get_device(),
+        )
 
         latents = get_latents_from_seed(params.seed, size)
         rng = np.random.RandomState(params.seed)
@@ -104,5 +85,5 @@ def blend_inpaint(
 
     output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
 
-    logger.info('final output image size', output.size)
+    logger.info("final output image size: %s", output.size)
     return output
diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py
index 5eae5b12..7104255f 100644
--- a/api/onnx_web/chain/correct_codeformer.py
+++ b/api/onnx_web/chain/correct_codeformer.py
@@ -1,30 +1,54 @@
+from logging import getLogger
+
+import torch
 from basicsr.utils import img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.registry import ARCH_REGISTRY
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from logging import getLogger
 from PIL import Image
 from torchvision.transforms.functional import normalize
 
-import torch
+from ..device_pool import JobContext
+from ..params import ImageParams, StageParams
+from ..utils import ServerContext
 
 logger = getLogger(__name__)
 
-pretrain_model_url = {
-    'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
-}
+pretrain_model_url = (
+    "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
+)
 
-device = 'cpu'
+device = "cpu"
 upscale = 2
 
-def correct_codeformer(image: Image.Image) -> Image.Image:
+
+def correct_codeformer(
+    job: JobContext,
+    server: ServerContext,
+    stage: StageParams,
+    params: ImageParams,
+    source_image: Image.Image,
+    **kwargs,
+) -> Image.Image:
+    model = "TODO"
+
     # ------------------ set up CodeFormer restorer -------------------
-    net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
-                                          connect_list=['32', '64', '128', '256']).to(device)
+    net = ARCH_REGISTRY.get("CodeFormer")(
+        dim_embd=512,
+        codebook_size=1024,
+        n_head=8,
+        n_layers=9,
+        connect_list=["32", "64", "128", "256"],
+    ).to(device)
 
     # ckpt_path = 'weights/CodeFormer/codeformer.pth'
-    ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
-                                   model_dir='weights/CodeFormer', progress=True, file_name=None)
-    checkpoint = torch.load(ckpt_path)['params_ema']
+    ckpt_path = 
load_file_from_url( + url=pretrain_model_url, + model_dir="weights/CodeFormer", + progress=True, + file_name=None, + ) + checkpoint = torch.load(ckpt_path)["params_ema"] net.load_state_dict(checkpoint) net.eval() @@ -36,22 +59,24 @@ def correct_codeformer(image: Image.Image) -> Image.Image: upscale, face_size=512, crop_ratio=(1, 1), - det_model = args.detection_model, - save_ext='png', + det_model=model, + save_ext="png", use_parse=True, - device=device) + device=device, + ) # get face landmarks for each face num_det_faces = face_helper.get_face_landmarks_5( - only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5) - logger.info('detect %s faces', num_det_faces) + only_center_face=False, resize=640, eye_dist_threshold=5 + ) + logger.info("detect %s faces", num_det_faces) # align and warp each face face_helper.align_warp_face() # face restoration for each cropped face for idx, cropped_face in enumerate(face_helper.cropped_faces): # prepare data - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) cropped_face_t = cropped_face_t.unsqueeze(0).to(device) @@ -62,10 +87,10 @@ def correct_codeformer(image: Image.Image) -> Image.Image: del output torch.cuda.empty_cache() except Exception as error: - logger.error('Failed inference for CodeFormer: %s', error) + logger.error("Failed inference for CodeFormer: %s", error) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - restored_face = restored_face.astype('uint8') + restored_face = restored_face.astype("uint8") face_helper.add_restored_face(restored_face, cropped_face) # upsample the background @@ -75,13 +100,16 @@ def correct_codeformer(image: Image.Image) -> Image.Image: else: bg_img = None - # paste_back face_helper.get_inverse_affine(None) # paste each restored face to the input image if face_upsampler is not None: - restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler) + restored_img = face_helper.paste_faces_to_input_image( + upsample_img=bg_img, draw_box=False, face_upsampler=face_upsampler + ) else: - restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=False) + restored_img = face_helper.paste_faces_to_input_image( + upsample_img=bg_img, draw_box=False + ) - return restored_img \ No newline at end of file + return restored_img diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 6dcbfbea..ffd55181 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -1,27 +1,16 @@ -from gfpgan import GFPGANer from logging import getLogger from os import path -from PIL import Image -from realesrgan import RealESRGANer from typing import Optional -from ..device_pool import ( - JobContext, -) -from ..params import ( - ImageParams, - StageParams, - UpscaleParams, -) -from ..utils import ( - run_gc, - ServerContext, -) -from .upscale_resrgan import ( - load_resrgan, -) - import numpy as np +from gfpgan import GFPGANer +from PIL import Image +from realesrgan import RealESRGANer + +from ..device_pool import JobContext +from ..params import ImageParams, StageParams, UpscaleParams +from ..utils import ServerContext, run_gc +from .upscale_resrgan import load_resrgan logger = getLogger(__name__) @@ -30,7 +19,9 @@ last_pipeline_instance = None last_pipeline_params = 
None -def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None): +def load_gfpgan( + ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[RealESRGANer] = None +): global last_pipeline_instance global last_pipeline_params @@ -38,22 +29,22 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, upsampler: Optional[ bg_upscale = upscale.rescale(upscale.outscale) upsampler = load_resrgan(ctx, bg_upscale) - face_path = path.join(ctx.model_path, '%s.pth' % - (upscale.correction_model)) + face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model)) - if last_pipeline_instance != None and face_path == last_pipeline_params: - logger.info('reusing existing GFPGAN pipeline') + if last_pipeline_instance is not None and face_path == last_pipeline_params: + logger.info("reusing existing GFPGAN pipeline") return last_pipeline_instance - logger.debug('loading GFPGAN model from %s', face_path) + logger.debug("loading GFPGAN model from %s", face_path) # TODO: find a way to pass the ONNX model to underlying architectures gfpgan = GFPGANer( model_path=face_path, upscale=upscale.outscale, - arch='clean', + arch="clean", channel_multiplier=2, - bg_upsampler=upsampler) + bg_upsampler=upsampler, + ) last_pipeline_instance = gfpgan last_pipeline_params = face_path @@ -74,15 +65,20 @@ def correct_gfpgan( **kwargs, ) -> Image.Image: if upscale.correction_model is None: - logger.warn('no face model given, skipping') + logger.warn("no face model given, skipping") return source_image - logger.info('correcting faces with GFPGAN model: %s', upscale.correction_model) + logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) gfpgan = load_gfpgan(server, upscale, upsampler=upsampler) output = np.array(source_image) _, _, output = gfpgan.enhance( - output, has_aligned=False, only_center_face=False, paste_back=True, weight=upscale.face_strength) - output = Image.fromarray(output, 'RGB') + output, + has_aligned=False, + only_center_face=False, + paste_back=True, + weight=upscale.face_strength, + ) + output = Image.fromarray(output, "RGB") return output diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index b9495352..abef41fb 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -1,26 +1,18 @@ from logging import getLogger + from PIL import Image -from ..device_pool import ( - JobContext, -) -from ..params import ( - ImageParams, - StageParams, -) -from ..output import ( - save_image, -) -from ..utils import ( - ServerContext, -) +from ..device_pool import JobContext +from ..output import save_image +from ..params import ImageParams, StageParams +from ..utils import ServerContext logger = getLogger(__name__) def persist_disk( _job: JobContext, - ctx: ServerContext, + server: ServerContext, _stage: StageParams, _params: ImageParams, source_image: Image.Image, @@ -28,6 +20,6 @@ def persist_disk( output: str, **kwargs, ) -> Image.Image: - dest = save_image(ctx, output, source_image) - logger.info('saved image to %s', dest) + dest = save_image(server, output, source_image) + logger.info("saved image to %s", dest) return source_image diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 1fd966fb..31be69f3 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -1,26 +1,19 @@ -from boto3 import ( - Session, -) from io import BytesIO from logging import getLogger + +from boto3 import Session from 
PIL import Image -from ..device_pool import ( - JobContext, -) -from ..params import ( - ImageParams, - StageParams, -) -from ..utils import ( - ServerContext, -) +from ..device_pool import JobContext +from ..params import ImageParams, StageParams +from ..utils import ServerContext logger = getLogger(__name__) def persist_s3( - ctx: ServerContext, + _job: JobContext, + server: ServerContext, _stage: StageParams, _params: ImageParams, source_image: Image.Image, @@ -32,16 +25,16 @@ def persist_s3( **kwargs, ) -> Image.Image: session = Session(profile_name=profile_name) - s3 = session.client('s3', endpoint_url=endpoint_url) + s3 = session.client("s3", endpoint_url=endpoint_url) data = BytesIO() - source_image.save(data, format=ctx.image_format) + source_image.save(data, format=server.image_format) data.seek(0) try: s3.upload_fileobj(data, bucket, output) - logger.info('saved image to %s/%s', bucket, output) + logger.info("saved image to %s/%s", bucket, output) except Exception as err: - logger.error('error saving image to S3: %s', err) + logger.error("error saving image to S3: %s", err) return source_image diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index ca21ae02..e8f70f7d 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -1,23 +1,17 @@ from logging import getLogger + from PIL import Image -from ..device_pool import ( - JobContext, -) -from ..params import ( - ImageParams, - Size, - StageParams, -) -from ..utils import ( - ServerContext, -) +from ..device_pool import JobContext +from ..params import ImageParams, Size, StageParams +from ..utils import ServerContext logger = getLogger(__name__) def reduce_crop( - ctx: ServerContext, + _job: JobContext, + _server: ServerContext, _stage: StageParams, _params: ImageParams, source_image: Image.Image, @@ -26,8 +20,6 @@ def reduce_crop( size: Size, **kwargs, ) -> Image.Image: - image = source_image.crop( - (origin.width, origin.height, size.width, size.height)) - logger.info('created thumbnail with dimensions: %sx%s', - image.width, image.height) + image = source_image.crop((origin.width, origin.height, size.width, size.height)) + logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) return image diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index 02498daa..b5b25130 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -1,23 +1,17 @@ from logging import getLogger + from PIL import Image -from ..device_pool import ( - JobContext, -) -from ..params import ( - ImageParams, - Size, - StageParams, -) -from ..utils import ( - ServerContext, -) +from ..device_pool import JobContext +from ..params import ImageParams, Size, StageParams +from ..utils import ServerContext logger = getLogger(__name__) def reduce_thumbnail( - ctx: ServerContext, + _job: JobContext, + _server: ServerContext, _stage: StageParams, _params: ImageParams, source_image: Image.Image, @@ -26,6 +20,5 @@ def reduce_thumbnail( **kwargs, ) -> Image.Image: image = source_image.thumbnail((size.width, size.height)) - logger.info('created thumbnail with dimensions: %sx%s', - image.width, image.height) + logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) return image diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 8bd360dc..9b01ccb6 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ 
-1,26 +1,19 @@ from logging import getLogger -from PIL import Image from typing import Callable -from ..device_pool import ( - JobContext, -) -from ..params import ( - ImageParams, - Size, - StageParams, -) -from ..utils import ( - ServerContext, -) +from PIL import Image +from ..device_pool import JobContext +from ..params import ImageParams, Size, StageParams +from ..utils import ServerContext logger = getLogger(__name__) def source_noise( - ctx: ServerContext, - stage: StageParams, + _job: JobContext, + _server: ServerContext, + _stage: StageParams, params: ImageParams, source_image: Image.Image, *, @@ -28,14 +21,12 @@ def source_noise( noise_source: Callable, **kwargs, ) -> Image.Image: - prompt = prompt or params.prompt - logger.info('generating image from noise source') + logger.info("generating image from noise source") if source_image is not None: - logger.warn( - 'a source image was passed to a noise stage, but will be discarded') + logger.warn("a source image was passed to a noise stage, but will be discarded") output = noise_source(source_image, (size.width, size.height), (0, 0)) - logger.info('final output image size: %sx%s', output.width, output.height) + logger.info("final output image size: %sx%s", output.width, output.height) return output diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index e72fba37..6122999e 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -1,26 +1,13 @@ -from diffusers import ( - OnnxStableDiffusionPipeline, -) from logging import getLogger -from PIL import Image - -from ..device_pool import ( - JobContext, -) -from ..diffusion.load import ( - get_latents_from_seed, - load_pipeline, -) -from ..params import ( - ImageParams, - Size, - StageParams, -) -from ..utils import ( - ServerContext, -) import numpy as np +from diffusers import OnnxStableDiffusionPipeline +from PIL import Image + +from ..device_pool import JobContext +from ..diffusion.load import get_latents_from_seed, load_pipeline +from ..params import ImageParams, Size, StageParams +from ..utils import ServerContext logger = getLogger(__name__) @@ -37,13 +24,16 @@ def source_txt2img( **kwargs, ) -> Image.Image: prompt = prompt or params.prompt - logger.info('generating image using txt2img, %s steps: %s', params.steps, prompt) + logger.info("generating image using txt2img, %s steps: %s", params.steps, prompt) if source_image is not None: - logger.warn('a source image was passed to a txt2img stage, but will be discarded') + logger.warn( + "a source image was passed to a txt2img stage, but will be discarded" + ) - pipe = load_pipeline(OnnxStableDiffusionPipeline, - params.model, params.scheduler, job.get_device()) + pipe = load_pipeline( + OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device() + ) latents = get_latents_from_seed(params.seed, size) rng = np.random.RandomState(params.seed) @@ -60,5 +50,5 @@ def source_txt2img( ) output = result.images[0] - logger.info('final output image size: %sx%s', output.width, output.height) + logger.info("final output image size: %sx%s", output.width, output.height) return output diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index dae724e2..a61d2ab5 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -1,43 +1,17 @@ -from diffusers import ( - OnnxStableDiffusionInpaintPipeline, -) from logging import getLogger -from PIL import Image, ImageDraw from typing 
import Callable, Tuple -from ..device_pool import ( - JobContext, -) -from ..diffusion.load import ( - get_latents_from_seed, - get_tile_latents, - load_pipeline, -) -from ..image import ( - expand_image, - mask_filter_none, - noise_source_histogram, -) -from ..params import ( - Border, - ImageParams, - Size, - SizeChart, - StageParams, -) -from ..output import ( - save_image, -) -from ..utils import ( - base_join, - is_debug, - ServerContext, -) -from .utils import ( - process_tile_spiral, -) - import numpy as np +from diffusers import OnnxStableDiffusionInpaintPipeline +from PIL import Image, ImageDraw + +from ..device_pool import JobContext +from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline +from ..image import expand_image, mask_filter_none, noise_source_histogram +from ..output import save_image +from ..params import Border, ImageParams, Size, SizeChart, StageParams +from ..utils import ServerContext, is_debug +from .utils import process_tile_spiral logger = getLogger(__name__) @@ -52,17 +26,17 @@ def upscale_outpaint( border: Border, prompt: str = None, mask_image: Image.Image = None, - fill_color: str = 'white', + fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, **kwargs, ) -> Image.Image: prompt = prompt or params.prompt - logger.info('upscaling image by expanding borders: %s', border) + logger.info("upscaling image by expanding borders: %s", border) if mask_image is None: # if no mask was provided, keep the full source image - mask_image = Image.new('RGB', source_image.size, 'black') + mask_image = Image.new("RGB", source_image.size, "black") source_image, mask_image, noise_image, full_dims = expand_image( source_image, @@ -70,16 +44,17 @@ def upscale_outpaint( border, fill=fill_color, noise_source=noise_source, - mask_filter=mask_filter) + mask_filter=mask_filter, + ) draw_mask = ImageDraw.Draw(mask_image) full_size = Size(*full_dims) full_latents = get_latents_from_seed(params.seed, full_size) if is_debug(): - save_image(server, 'last-source.png', source_image) - save_image(server, 'last-mask.png', mask_image) - save_image(server, 'last-noise.png', noise_image) + save_image(server, "last-source.png", source_image) + save_image(server, "last-mask.png", mask_image) + save_image(server, "last-noise.png", noise_image) def outpaint(image: Image.Image, dims: Tuple[int, int, int]): left, top, tile = dims @@ -87,11 +62,15 @@ def upscale_outpaint( mask = mask_image.crop((left, top, left + tile, top + tile)) if is_debug(): - save_image(server, 'tile-source.png', image) - save_image(server, 'tile-mask.png', mask) + save_image(server, "tile-source.png", image) + save_image(server, "tile-mask.png", mask) - pipe = load_pipeline(OnnxStableDiffusionInpaintPipeline, - params.model, params.scheduler, job.get_device()) + pipe = load_pipeline( + OnnxStableDiffusionInpaintPipeline, + params.model, + params.scheduler, + job.get_device(), + ) latents = get_tile_latents(full_latents, dims) rng = np.random.RandomState(params.seed) @@ -110,10 +89,10 @@ def upscale_outpaint( ) # once part of the image has been drawn, keep it - draw_mask.rectangle((left, top, left + tile, top + tile), fill='black') + draw_mask.rectangle((left, top, left + tile, top + tile), fill="black") return result.images[0] output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint]) - logger.info('final output image size: %sx%s', output.width, output.height) + logger.info("final output image size: %sx%s", output.width, 
output.height)
     return output
diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py
index 885c1d0c..e22522be 100644
--- a/api/onnx_web/chain/upscale_resrgan.py
+++ b/api/onnx_web/chain/upscale_resrgan.py
@@ -1,27 +1,15 @@
-from basicsr.archs.rrdbnet_arch import RRDBNet
 from logging import getLogger
 from os import path
+
+import numpy as np
+from basicsr.archs.rrdbnet_arch import RRDBNet
 from PIL import Image
 from realesrgan import RealESRGANer
 
-from ..device_pool import (
-    JobContext,
-)
-from ..onnx import (
-    OnnxNet,
-)
-from ..params import (
-    DeviceParams,
-    ImageParams,
-    StageParams,
-    UpscaleParams,
-)
-from ..utils import (
-    run_gc,
-    ServerContext,
-)
-
-import numpy as np
+from ..device_pool import JobContext
+from ..onnx import OnnxNet
+from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
+from ..utils import ServerContext, run_gc
 
 logger = getLogger(__name__)
 
@@ -29,40 +17,51 @@ last_pipeline_instance = None
 last_pipeline_params = (None, None)
 
 
-def load_resrgan(ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0):
+def load_resrgan(
+    ctx: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0
+):
     global last_pipeline_instance
     global last_pipeline_params
 
-    model_file = '%s.%s' % (params.upscale_model, params.format)
+    model_file = "%s.%s" % (params.upscale_model, params.format)
     model_path = path.join(ctx.model_path, model_file)
     if not path.isfile(model_path):
-        raise Exception('Real ESRGAN model not found at %s' % model_path)
+        raise Exception("Real ESRGAN model not found at %s" % model_path)
 
    cache_params = (model_path, params.format)
-    if last_pipeline_instance != None and cache_params == last_pipeline_params:
-        logger.info('reusing existing Real ESRGAN pipeline')
+    if last_pipeline_instance is not None and cache_params == last_pipeline_params:
+        logger.info("reusing existing Real ESRGAN pipeline")
         return last_pipeline_instance
 
     # use ONNX acceleration, if available
-    if params.format == 'onnx':
-        model = OnnxNet(ctx, model_file, provider=device.provider, provider_options=device.options)
-    elif params.format == 'pth':
-        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
-                        num_block=23, num_grow_ch=32, scale=params.scale)
-    else:
-        raise Exception('unknown platform %s' % params.format)
+    if params.format == "onnx":
+        model = OnnxNet(
+            ctx, model_file, provider=device.provider, provider_options=device.options
+        )
+    elif params.format == "pth":
+        model = RRDBNet(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=params.scale,
+        )
+    else:
+        raise Exception("unknown platform %s" % params.format)
 
     dni_weight = None
-    if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
+    if params.upscale_model == "realesr-general-x4v3" and params.denoise != 1:
         wdn_model_path = model_path.replace(
-            'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+            "realesr-general-x4v3", "realesr-general-wdn-x4v3"
+        )
         model_path = [model_path, wdn_model_path]
         dni_weight = [params.denoise, 1 - params.denoise]
 
-    logger.debug('loading Real ESRGAN upscale model from %s', model_path)
+    logger.debug("loading Real ESRGAN upscale model from %s", model_path)
 
     # TODO: shouldn't need the PTH file
-    model_path_pth = path.join(ctx.model_path, '%s.pth' % params.upscale_model)
+    model_path_pth = path.join(ctx.model_path, "%s.pth" % params.upscale_model)
     upsampler = RealESRGANer(
         scale=params.scale,
         model_path=model_path_pth,
@@ -70,7 +69,8 @@ def load_resrgan(ctx: ServerContext, params: 
UpscaleParams, device: DeviceParams tile=tile, tile_pad=params.tile_pad, pre_pad=params.pre_pad, - half=params.half) + half=params.half, + ) last_pipeline_instance = upsampler last_pipeline_params = cache_params @@ -89,13 +89,13 @@ def upscale_resrgan( upscale: UpscaleParams, **kwargs, ) -> Image.Image: - logger.info('upscaling image with Real ESRGAN: x%s', upscale.scale) + logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) output = np.array(source_image) upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size) output, _ = upsampler.enhance(output, outscale=upscale.outscale) - output = Image.fromarray(output, 'RGB') - logger.info('final output image size: %sx%s', output.width, output.height) + output = Image.fromarray(output, "RGB") + logger.info("final output image size: %sx%s", output.width, output.height) return output diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index c9d89b86..a7f37d4c 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -1,28 +1,16 @@ -from diffusers import ( - StableDiffusionUpscalePipeline, -) from logging import getLogger from os import path + +import torch +from diffusers import StableDiffusionUpscalePipeline from PIL import Image -from ..device_pool import ( - JobContext, -) +from ..device_pool import JobContext from ..diffusion.pipeline_onnx_stable_diffusion_upscale import ( OnnxStableDiffusionUpscalePipeline, ) -from ..params import ( - DeviceParams, - ImageParams, - StageParams, - UpscaleParams, -) -from ..utils import ( - run_gc, - ServerContext, -) - -import torch +from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..utils import ServerContext, run_gc logger = getLogger(__name__) @@ -31,23 +19,37 @@ last_pipeline_instance = None last_pipeline_params = (None, None) -def load_stable_diffusion(ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams): +def load_stable_diffusion( + ctx: ServerContext, upscale: UpscaleParams, device: DeviceParams +): global last_pipeline_instance global last_pipeline_params model_path = path.join(ctx.model_path, upscale.upscale_model) cache_params = (model_path, upscale.format) - if last_pipeline_instance != None and cache_params == last_pipeline_params: - logger.info('reusing existing Stable Diffusion upscale pipeline') + if last_pipeline_instance is not None and cache_params == last_pipeline_params: + logger.info("reusing existing Stable Diffusion upscale pipeline") return last_pipeline_instance - if upscale.format == 'onnx': - logger.debug('loading Stable Diffusion upscale ONNX model from %s, using provider %s', model_path, device.provider) - pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider, provider_options=device.options) + if upscale.format == "onnx": + logger.debug( + "loading Stable Diffusion upscale ONNX model from %s, using provider %s", + model_path, + device.provider, + ) + pipeline = OnnxStableDiffusionUpscalePipeline.from_pretrained( + model_path, provider=device.provider, provider_options=device.options + ) else: - logger.debug('loading Stable Diffusion upscale model from %s, using provider %s', model_path, device.provider) - pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_path, provider=device.provider) + logger.debug( + "loading Stable Diffusion upscale model from %s, using provider %s", + model_path, + device.provider, + ) + pipeline = 
StableDiffusionUpscalePipeline.from_pretrained( + model_path, provider=device.provider + ) last_pipeline_instance = pipeline last_pipeline_params = cache_params @@ -68,7 +70,7 @@ def upscale_stable_diffusion( **kwargs, ) -> Image.Image: prompt = prompt or params.prompt - logger.info('upscaling with Stable Diffusion, %s steps: %s', params.steps, prompt) + logger.info("upscaling with Stable Diffusion, %s steps: %s", params.steps, prompt) pipeline = load_stable_diffusion(server, upscale, job.get_device()) generator = torch.manual_seed(params.seed) diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index e4f2162d..0e112303 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -1,7 +1,8 @@ from logging import getLogger -from PIL import Image from typing import List, Protocol, Tuple +from PIL import Image + logger = getLogger(__name__) @@ -17,7 +18,7 @@ def process_tile_grid( filters: List[TileCallback], ) -> Image.Image: width, height = source.size - image = Image.new('RGB', (width * scale, height * scale)) + image = Image.new("RGB", (width * scale, height * scale)) tiles_x = width // tile tiles_y = height // tile @@ -28,7 +29,7 @@ def process_tile_grid( idx = (y * tiles_x) + x left = x * tile top = y * tile - logger.info('processing tile %s of %s, %s.%s', idx + 1, total, y, x) + logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x) tile_image = source.crop((left, top, left + tile, top + tile)) for filter in filters: @@ -47,10 +48,10 @@ def process_tile_spiral( overlap: float = 0.5, ) -> Image.Image: if scale != 1: - raise Exception('unsupported scale') + raise Exception("unsupported scale") width, height = source.size - image = Image.new('RGB', (width * scale, height * scale)) + image = Image.new("RGB", (width * scale, height * scale)) image.paste(source, (0, 0, width, height)) center_x = (width // 2) - (tile // 2) @@ -76,7 +77,7 @@ def process_tile_spiral( top = center_y + int(top) counter += 1 - logger.info('processing tile %s of %s, %sx%s', counter, len(tiles), left, top) + logger.info("processing tile %s of %s, %sx%s", counter, len(tiles), left, top) # TODO: only valid for scale == 1, resize source for others tile_image = image.crop((left, top, left + tile, top + tile)) diff --git a/api/onnx_web/convert.py b/api/onnx_web/convert.py index 4e585b40..eb9c6d97 100644 --- a/api/onnx_web/convert.py +++ b/api/onnx_web/convert.py @@ -1,5 +1,14 @@ -from . import logging +import warnings from argparse import ArgumentParser +from json import loads +from logging import getLogger +from os import environ, makedirs, mkdir, path +from pathlib import Path +from shutil import copyfile, rmtree +from sys import exit +from typing import Dict, List, Optional, Tuple + +import torch from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.utils.download_util import load_file_from_url from diffusers import ( @@ -8,25 +17,20 @@ from diffusers import ( StableDiffusionPipeline, StableDiffusionUpscalePipeline, ) -from json import loads -from logging import getLogger from onnx import load, save_model -from os import environ, makedirs, mkdir, path -from pathlib import Path -from shutil import copyfile, rmtree -from sys import exit from torch.onnx import export -from typing import Dict, List, Optional, Tuple -import torch -import warnings +from . 
import logging # suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75 warnings.filterwarnings( - 'ignore', '.*The shape inference of prim::Constant type is missing.*') -warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*') + "ignore", ".*The shape inference of prim::Constant type is missing.*" +) +warnings.filterwarnings("ignore", ".*Only steps=1 can be constant folded.*") warnings.filterwarnings( - 'ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*') + "ignore", + ".*Converting a tensor to a Python boolean might cause the trace to be incorrect.*", +) Models = Dict[str, List[Tuple[str, str, Optional[int]]]] @@ -35,74 +39,95 @@ logger = getLogger(__name__) # recommended models base_models: Models = { - 'diffusion': [ + "diffusion": [ # v1.x - ('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'), - ('stable-diffusion-onnx-v1-inpainting', - 'runwayml/stable-diffusion-inpainting'), + ("stable-diffusion-onnx-v1-5", "runwayml/stable-diffusion-v1-5"), + ("stable-diffusion-onnx-v1-inpainting", "runwayml/stable-diffusion-inpainting"), # v2.x - ('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'), - ('stable-diffusion-onnx-v2-inpainting', - 'stabilityai/stable-diffusion-2-inpainting'), + ("stable-diffusion-onnx-v2-1", "stabilityai/stable-diffusion-2-1"), + ( + "stable-diffusion-onnx-v2-inpainting", + "stabilityai/stable-diffusion-2-inpainting", + ), # TODO: should have its own converter - ('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'), + ("upscaling-stable-diffusion-x4", "stabilityai/stable-diffusion-x4-upscaler"), ], - 'correction': [ - ('correction-gfpgan-v1-3', - 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4), + "correction": [ + ( + "correction-gfpgan-v1-3", + "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth", + 4, + ), ], - 'upscaling': [ - ('upscaling-real-esrgan-x2-plus', - 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2), - ('upscaling-real-esrgan-x4-plus', - 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4), - ('upscaling-real-esrgan-x4-v3', - 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 4), + "upscaling": [ + ( + "upscaling-real-esrgan-x2-plus", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + 2, + ), + ( + "upscaling-real-esrgan-x4-plus", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + 4, + ), + ( + "upscaling-real-esrgan-x4-v3", + "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + 4, + ), ], } -model_path = environ.get('ONNX_WEB_MODEL_PATH', - path.join('..', 'models')) -training_device = 'cuda' if torch.cuda.is_available() else 'cpu' +model_path = environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")) +training_device = "cuda" if torch.cuda.is_available() else "cpu" map_location = torch.device(training_device) @torch.no_grad() def convert_real_esrgan(name: str, url: str, scale: int, opset: int): - dest_path = path.join(model_path, name + '.pth') - dest_onnx = path.join(model_path, name + '.onnx') - logger.info('converting Real ESRGAN model: %s -> %s', name, dest_onnx) + dest_path = path.join(model_path, name + ".pth") + dest_onnx = path.join(model_path, name + ".onnx") + logger.info("converting Real 
ESRGAN model: %s -> %s", name, dest_onnx) if path.isfile(dest_onnx): - logger.info('ONNX model already exists, skipping.') + logger.info("ONNX model already exists, skipping.") return if not path.isfile(dest_path): - logger.info('PTH model not found, downloading...') + logger.info("PTH model not found, downloading...") download_path = load_file_from_url( - url=url, model_dir=dest_path + '-cache', progress=True, file_name=None) + url=url, model_dir=dest_path + "-cache", progress=True, file_name=None + ) copyfile(download_path, dest_path) - logger.info('loading and training model') - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, - num_block=23, num_grow_ch=32, scale=scale) + logger.info("loading and training model") + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=scale, + ) torch_model = torch.load(dest_path, map_location=map_location) - if 'params_ema' in torch_model: - model.load_state_dict(torch_model['params_ema']) + if "params_ema" in torch_model: + model.load_state_dict(torch_model["params_ema"]) else: - model.load_state_dict(torch_model['params'], strict=False) + model.load_state_dict(torch_model["params"], strict=False) model.to(training_device).train(False) model.eval() rng = torch.rand(1, 3, 64, 64, device=map_location) - input_names = ['data'] - output_names = ['output'] - dynamic_axes = {'data': {2: 'width', 3: 'height'}, - 'output': {2: 'width', 3: 'height'}} + input_names = ["data"] + output_names = ["output"] + dynamic_axes = { + "data": {2: "width", 3: "height"}, + "output": {2: "width", 3: "height"}, + } - logger.info('exporting ONNX model to %s', dest_onnx) + logger.info("exporting ONNX model to %s", dest_onnx) export( model, rng, @@ -111,48 +136,57 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int): output_names=output_names, dynamic_axes=dynamic_axes, opset_version=opset, - export_params=True + export_params=True, ) - logger.info('Real ESRGAN exported to ONNX successfully.') + logger.info("Real ESRGAN exported to ONNX successfully.") @torch.no_grad() def convert_gfpgan(name: str, url: str, scale: int, opset: int): - dest_path = path.join(model_path, name + '.pth') - dest_onnx = path.join(model_path, name + '.onnx') - logger.info('converting GFPGAN model: %s -> %s', name, dest_onnx) + dest_path = path.join(model_path, name + ".pth") + dest_onnx = path.join(model_path, name + ".onnx") + logger.info("converting GFPGAN model: %s -> %s", name, dest_onnx) if path.isfile(dest_onnx): - logger.info('ONNX model already exists, skipping.') + logger.info("ONNX model already exists, skipping.") return if not path.isfile(dest_path): - logger.info('PTH model not found, downloading...') + logger.info("PTH model not found, downloading...") download_path = load_file_from_url( - url=url, model_dir=dest_path + '-cache', progress=True, file_name=None) + url=url, model_dir=dest_path + "-cache", progress=True, file_name=None + ) copyfile(download_path, dest_path) - logger.info('loading and training model') - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, - num_block=23, num_grow_ch=32, scale=scale) + logger.info("loading and training model") + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=scale, + ) torch_model = torch.load(dest_path, map_location=map_location) # TODO: make sure strict=False is safe here - if 'params_ema' in torch_model: - model.load_state_dict(torch_model['params_ema'], strict=False) + if "params_ema" in torch_model: + 
model.load_state_dict(torch_model["params_ema"], strict=False) else: - model.load_state_dict(torch_model['params'], strict=False) + model.load_state_dict(torch_model["params"], strict=False) model.to(training_device).train(False) model.eval() rng = torch.rand(1, 3, 64, 64, device=map_location) - input_names = ['data'] - output_names = ['output'] - dynamic_axes = {'data': {2: 'width', 3: 'height'}, - 'output': {2: 'width', 3: 'height'}} + input_names = ["data"] + output_names = ["output"] + dynamic_axes = { + "data": {2: "width", 3: "height"}, + "output": {2: "width", 3: "height"}, + } - logger.info('exporting ONNX model to %s', dest_onnx) + logger.info("exporting ONNX model to %s", dest_onnx) export( model, rng, @@ -161,9 +195,9 @@ def convert_gfpgan(name: str, url: str, scale: int, opset: int): output_names=output_names, dynamic_axes=dynamic_axes, opset_version=opset, - export_params=True + export_params=True, ) - logger.info('GFPGAN exported to ONNX successfully.') + logger.info("GFPGAN exported to ONNX successfully.") def onnx_export( @@ -176,9 +210,9 @@ def onnx_export( opset, use_external_data_format=False, ): - ''' + """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py - ''' + """ output_path.parent.mkdir(parents=True, exist_ok=True) export( @@ -194,29 +228,33 @@ def onnx_export( @torch.no_grad() -def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False): - ''' +def convert_diffuser( + name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False +): + """ From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py - ''' + """ dtype = torch.float16 if half else torch.float32 dest_path = path.join(model_path, name) # diffusers go into a directory rather than .onnx file - logger.info('converting Diffusers model: %s -> %s/', name, dest_path) + logger.info("converting Diffusers model: %s -> %s/", name, dest_path) if single_vae: - logger.info('converting model with single VAE') + logger.info("converting model with single VAE") if path.isdir(dest_path): - logger.info('ONNX model already exists, skipping.') + logger.info("ONNX model already exists, skipping.") return - if half and training_device != 'cuda': + if half and training_device != "cuda": raise ValueError( - 'Half precision model export is only supported on GPUs with CUDA') + "Half precision model export is only supported on GPUs with CUDA" + ) pipeline = StableDiffusionPipeline.from_pretrained( - url, torch_dtype=dtype, use_auth_token=token).to(training_device) + url, torch_dtype=dtype, use_auth_token=token + ).to(training_device) output_path = Path(dest_path) # TEXT ENCODER @@ -232,8 +270,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si onnx_export( pipeline.text_encoder, # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files - model_args=(text_input.input_ids.to( - device=training_device, dtype=torch.int32)), + model_args=(text_input.input_ids.to(device=training_device, dtype=torch.int32)), output_path=output_path / "text_encoder" / "model.onnx", ordered_input_names=["input_ids"], output_names=["last_hidden_state", "pooler_output"], @@ -244,7 +281,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si ) del pipeline.text_encoder - logger.debug('UNET config: %s', pipeline.unet.config) + logger.debug("UNET config: %s", 
pipeline.unet.config) # UNET if single_vae: @@ -262,10 +299,12 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si pipeline.unet, model_args=( torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to( - device=training_device, dtype=dtype), + device=training_device, dtype=dtype + ), torch.randn(2).to(device=training_device, dtype=dtype), torch.randn(2, num_tokens, text_hidden_size).to( - device=training_device, dtype=dtype), + device=training_device, dtype=dtype + ), unet_scale, ), output_path=unet_path, @@ -298,7 +337,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si del pipeline.unet if single_vae: - logger.debug('VAE config: %s', pipeline.vae.config) + logger.debug("VAE config: %s", pipeline.vae.config) # SINGLE VAE vae_only = pipeline.vae @@ -309,8 +348,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si onnx_export( vae_only, model_args=( - torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to( - device=training_device, dtype=dtype), + torch.randn( + 1, vae_latent_channels, unet_sample_size, unet_sample_size + ).to(device=training_device, dtype=dtype), False, ), output_path=output_path / "vae" / "model.onnx", @@ -328,12 +368,14 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si vae_sample_size = vae_encoder.config.sample_size # need to get the raw tensor output (sample) from the encoder vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode( - sample, return_dict)[0].sample() + sample, return_dict + )[0].sample() onnx_export( vae_encoder, model_args=( torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to( - device=training_device, dtype=dtype), + device=training_device, dtype=dtype + ), False, ), output_path=output_path / "vae_encoder" / "model.onnx", @@ -354,8 +396,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si onnx_export( vae_decoder, model_args=( - torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to( - device=training_device, dtype=dtype), + torch.randn( + 1, vae_latent_channels, unet_sample_size, unet_sample_size + ).to(device=training_device, dtype=dtype), False, ), output_path=output_path / "vae_decoder" / "model.onnx", @@ -385,7 +428,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si clip_image_size, ).to(device=training_device, dtype=dtype), torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to( - device=training_device, dtype=dtype), + device=training_device, dtype=dtype + ), ), output_path=output_path / "safety_checker" / "model.onnx", ordered_input_names=["clip_input", "images"], @@ -398,7 +442,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si ) del pipeline.safety_checker safety_checker = OnnxRuntimeModel.from_pretrained( - output_path / "safety_checker") + output_path / "safety_checker" + ) feature_extractor = pipeline.feature_extractor else: safety_checker = None @@ -406,10 +451,8 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si if single_vae: onnx_pipeline = StableDiffusionUpscalePipeline( - vae=OnnxRuntimeModel.from_pretrained( - output_path / "vae"), - text_encoder=OnnxRuntimeModel.from_pretrained( - output_path / "text_encoder"), + vae=OnnxRuntimeModel.from_pretrained(output_path / "vae"), + text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"), 
tokenizer=pipeline.tokenizer, low_res_scheduler=pipeline.scheduler, unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), @@ -417,12 +460,9 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si ) else: onnx_pipeline = OnnxStableDiffusionPipeline( - vae_encoder=OnnxRuntimeModel.from_pretrained( - output_path / "vae_encoder"), - vae_decoder=OnnxRuntimeModel.from_pretrained( - output_path / "vae_decoder"), - text_encoder=OnnxRuntimeModel.from_pretrained( - output_path / "text_encoder"), + vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"), + vae_decoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_decoder"), + text_encoder=OnnxRuntimeModel.from_pretrained(output_path / "text_encoder"), tokenizer=pipeline.tokenizer, unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"), scheduler=pipeline.scheduler, @@ -431,7 +471,7 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si requires_safety_checker=safety_checker is not None, ) - logger.info('exporting ONNX model') + logger.info("exporting ONNX model") onnx_pipeline.save_pretrained(output_path) logger.info("ONNX pipeline saved to %s", output_path) @@ -445,90 +485,93 @@ def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, si ) else: _ = OnnxStableDiffusionPipeline.from_pretrained( - output_path, provider="CPUExecutionProvider") + output_path, provider="CPUExecutionProvider" + ) logger.info("ONNX pipeline is loadable") def load_models(args, models: Models): if args.diffusion: - for source in models.get('diffusion'): + for source in models.get("diffusion"): if source[0] in args.skip: - logger.info('Skipping model: %s', source[0]) + logger.info("Skipping model: %s", source[0]) else: - single_vae = 'upscaling' in source[0] - convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae) + single_vae = "upscaling" in source[0] + convert_diffuser( + *source, args.opset, args.half, args.token, single_vae=single_vae + ) if args.upscaling: - for source in models.get('upscaling'): + for source in models.get("upscaling"): if source[0] in args.skip: - logger.info('Skipping model: %s', source[0]) + logger.info("Skipping model: %s", source[0]) else: convert_real_esrgan(*source, args.opset) if args.correction: - for source in models.get('correction'): + for source in models.get("correction"): if source[0] in args.skip: - logger.info('Skipping model: %s', source[0]) + logger.info("Skipping model: %s", source[0]) else: convert_gfpgan(*source, args.opset) def main() -> int: parser = ArgumentParser( - prog='onnx-web model converter', - description='convert checkpoint models to ONNX') + prog="onnx-web model converter", description="convert checkpoint models to ONNX" + ) # model groups - parser.add_argument('--correction', action='store_true', default=False) - parser.add_argument('--diffusion', action='store_true', default=False) - parser.add_argument('--upscaling', action='store_true', default=False) + parser.add_argument("--correction", action="store_true", default=False) + parser.add_argument("--diffusion", action="store_true", default=False) + parser.add_argument("--upscaling", action="store_true", default=False) # extra models - parser.add_argument('--extras', nargs='*', type=str, default=[]) - parser.add_argument('--skip', nargs='*', type=str, default=[]) + parser.add_argument("--extras", nargs="*", type=str, default=[]) + parser.add_argument("--skip", nargs="*", type=str, default=[]) # export options 
parser.add_argument(
-        '--half',
-        action='store_true',
+        "--half",
+        action="store_true",
         default=False,
-        help='Export models for half precision, faster on some Nvidia cards.'
+        help="Export models for half precision, faster on some Nvidia cards.",
     )
     parser.add_argument(
-        '--opset',
+        "--opset",
         default=14,
         type=int,
         help="The version of the ONNX operator set to use.",
     )
     parser.add_argument(
-        '--token',
+        "--token",
         type=str,
         help="HuggingFace token with read permissions for downloading models.",
     )
 
     args = parser.parse_args()
-    logger.info('CLI arguments: %s', args)
+    logger.info("CLI arguments: %s", args)
 
     if not path.exists(model_path):
-        logger.info('Model path does not existing, creating: %s', model_path)
+        logger.info("Model path does not exist, creating: %s", model_path)
         makedirs(model_path)
 
-    logger.info('Converting base models.')
+    logger.info("Converting base models.")
     load_models(args, base_models)
 
     for file in args.extras:
-        logger.info('Loading extra models from %s', file)
+        logger.info("Loading extra models from %s", file)
         try:
-            with open(file, 'r') as f:
+            with open(file, "r") as f:
                 data = loads(f.read())
-                logger.info('Converting extra models.')
+                logger.info("Converting extra models.")
                 load_models(args, data)
         except Exception as err:
-            logger.error('Error converting extra models: %s', err)
+            logger.error("Error converting extra models: %s", err)
 
     return 0
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     exit(main())
diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py
index 7c6b051f..a2ef1c0e 100644
--- a/api/onnx_web/device_pool.py
+++ b/api/onnx_web/device_pool.py
@@ -1,12 +1,10 @@
 from collections import Counter
-from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
+from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from logging import getLogger
 from multiprocessing import Value
 from typing import Any, Callable, List, Optional, Tuple, Union
 
-from .params import (
-    DeviceParams,
-)
+from .params import DeviceParams
 
 logger = getLogger(__name__)
 
@@ -28,24 +26,24 @@ class JobContext:
     ):
         self.key = key
         self.devices = list(devices)
-        self.cancel = Value('B', cancel)
-        self.device_index = Value('i', device_index)
-        self.progress = Value('I', progress)
+        self.cancel = Value("B", cancel)
+        self.device_index = Value("i", device_index)
+        self.progress = Value("I", progress)
 
     def is_cancelled(self) -> bool:
         return self.cancel.value
 
     def get_device(self) -> DeviceParams:
-        '''
+        """
         Get the device assigned to this job.
-        '''
+        """
         with self.device_index.get_lock():
             device_index = self.device_index.value
             if device_index < 0:
-                raise Exception('job has not been assigned to a device')
+                raise Exception("job has not been assigned to a device")
             else:
                 device = self.devices[device_index]
-                logger.debug('job %s assigned to device %s', self.key, device)
+                logger.debug("job %s assigned to device %s", self.key, device)
                 return device
 
     def get_progress(self) -> int:
@@ -54,10 +52,9 @@ class JobContext:
     def get_progress_callback(self) -> Callable[..., None]:
         def on_progress(step: int, timestep: int, latents: Any):
             if self.is_cancelled():
-                raise Exception('job has been cancelled')
+                raise Exception("job has been cancelled")
             else:
-                logger.debug('setting progress for job %s to %s',
-                             self.key, step)
+                logger.debug("setting progress for job %s to %s", self.key, step)
                 self.set_progress(step)
 
         return on_progress
 
@@ -72,9 +69,9 @@ class JobContext:
 
 
 class Job:
-    '''
+    """
     Link a future to its context.
- ''' + """ context: JobContext = None future: Future = None @@ -106,7 +103,11 @@ class DevicePoolExecutor: next_device: int = 0 pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None - def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None): + def __init__( + self, + devices: List[DeviceParams], + pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None, + ): self.devices = devices self.jobs = [] self.next_device = 0 @@ -114,19 +115,25 @@ class DevicePoolExecutor: device_count = len(devices) if pool is None: logger.info( - 'creating thread pool executor for %s devices: %s', device_count, [d.device for d in devices]) + "creating thread pool executor for %s devices: %s", + device_count, + [d.device for d in devices], + ) self.pool = ThreadPoolExecutor(device_count) else: - logger.info('using existing pool for %s devices: %s', - device_count, [d.device for d in devices]) + logger.info( + "using existing pool for %s devices: %s", + device_count, + [d.device for d in devices], + ) self.pool = pool def cancel(self, key: str) -> bool: - ''' + """ Cancel a job. If the job has not been started, this will cancel the future and never execute it. If the job has been started, it should be cancelled on the next progress callback. - ''' + """ for job in self.jobs: if job.key == key: if job.future.cancel(): @@ -144,7 +151,7 @@ class DevicePoolExecutor: progress = job.get_progress() return (done, progress) - logger.warn('checking status for unknown key: %s', key) + logger.warn("checking status for unknown key: %s", key) return (None, 0) def get_next_device(self): @@ -152,12 +159,14 @@ class DevicePoolExecutor: if len(self.jobs) == 0: return 0 - job_devices = [job.context.device_index.value for job in self.jobs if not job.future.done()] + job_devices = [ + job.context.device_index.value for job in self.jobs if not job.future.done() + ] job_counts = Counter(range(len(self.devices))) job_counts.update(job_devices) queued = job_counts.most_common() - logger.debug('jobs queued by device: %s', queued) + logger.debug("jobs queued by device: %s", queued) lowest_count = queued[-1][1] lowest_devices = [d[0] for d in queued if d[1] == lowest_count] @@ -170,7 +179,7 @@ class DevicePoolExecutor: def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None: device = self.get_next_device() - logger.info('assigning job %s to device %s', key, device) + logger.info("assigning job %s to device %s", key, device) context = JobContext(key, self.devices, device_index=device) future = self.pool.submit(fn, context, *args, **kwargs) @@ -180,11 +189,19 @@ class DevicePoolExecutor: def job_done(f: Future): try: f.result() - logger.info('job %s finished successfully', key) + logger.info("job %s finished successfully", key) except Exception as err: - logger.warn('job %s failed with an error: %s', key, err) + logger.warn("job %s failed with an error: %s", key, err) future.add_done_callback(job_done) def status(self) -> List[Tuple[str, int, bool, int]]: - return [(job.key, job.context.device_index.value, job.future.done(), job.get_progress()) for job in self.jobs] + return [ + ( + job.key, + job.context.device_index.value, + job.future.done(), + job.get_progress(), + ) + for job in self.jobs + ] diff --git a/api/onnx_web/diffusion/load.py b/api/onnx_web/diffusion/load.py index d45a7f6b..00c0bf2e 100644 --- a/api/onnx_web/diffusion/load.py +++ b/api/onnx_web/diffusion/load.py @@ -1,18 +1,11 @@ -from diffusers import ( - DiffusionPipeline, 
-) from logging import getLogger -from typing import Any, Optional, Tuple - -from ..params import ( - DeviceParams, - Size, -) -from ..utils import ( - run_gc, -) +from typing import Any, Tuple import numpy as np +from diffusers import DiffusionPipeline + +from ..params import DeviceParams, Size +from ..utils import run_gc logger = getLogger(__name__) @@ -25,17 +18,23 @@ latent_factor = 8 def get_latents_from_seed(seed: int, size: Size, batch: int = 1) -> np.ndarray: - ''' + """ From https://www.travelneil.com/stable-diffusion-updates.html - ''' - latents_shape = (batch, latent_channels, size.height // latent_factor, - size.width // latent_factor) + """ + latents_shape = ( + batch, + latent_channels, + size.height // latent_factor, + size.width // latent_factor, + ) rng = np.random.default_rng(seed) image_latents = rng.standard_normal(latents_shape).astype(np.float32) return image_latents -def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np.ndarray: +def get_tile_latents( + full_latents: np.ndarray, dims: Tuple[int, int, int] +) -> np.ndarray: x, y, tile = dims t = tile // latent_factor x = x // latent_factor @@ -46,27 +45,29 @@ def get_tile_latents(full_latents: np.ndarray, dims: Tuple[int, int, int]) -> np return full_latents[:, :, y:yt, x:xt] -def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams): +def load_pipeline( + pipeline: DiffusionPipeline, model: str, scheduler: Any, device: DeviceParams +): global last_pipeline_instance global last_pipeline_scheduler global last_pipeline_options options = (pipeline, model, device.provider) - if last_pipeline_instance != None and last_pipeline_options == options: - logger.debug('reusing existing diffusion pipeline') + if last_pipeline_instance is not None and last_pipeline_options == options: + logger.debug("reusing existing diffusion pipeline") pipe = last_pipeline_instance else: - logger.debug('unloading previous diffusion pipeline') + logger.debug("unloading previous diffusion pipeline") last_pipeline_instance = None last_pipeline_scheduler = None run_gc() - logger.debug('loading new diffusion pipeline from %s', model) + logger.debug("loading new diffusion pipeline from %s", model) scheduler = scheduler.from_pretrained( model, provider=device.provider, provider_options=device.options, - subfolder='scheduler', + subfolder="scheduler", ) pipe = pipeline.from_pretrained( model, @@ -76,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic scheduler=scheduler, ) - if device is not None and hasattr(pipe, 'to'): + if device is not None and hasattr(pipe, "to"): pipe = pipe.to(device) last_pipeline_instance = pipe @@ -84,15 +85,15 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, scheduler: Any, devic last_pipeline_scheduler = scheduler if last_pipeline_scheduler != scheduler: - logger.debug('loading new diffusion scheduler') + logger.debug("loading new diffusion scheduler") scheduler = scheduler.from_pretrained( model, provider=device.provider, provider_options=device.options, - subfolder='scheduler', + subfolder="scheduler", ) - if device is not None and hasattr(scheduler, 'to'): + if device is not None and hasattr(scheduler, "to"): scheduler = scheduler.to(device) pipe.scheduler = scheduler diff --git a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py index 8e25d4d9..b696d349 100644 --- 
a/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/api/onnx_web/diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -1,28 +1,16 @@ -from diffusers import ( - DDPMScheduler, - OnnxRuntimeModel, - StableDiffusionUpscalePipeline, -) -from diffusers.pipeline_utils import ( - ImagePipelineOutput, -) from logging import getLogger -from PIL import Image -from typing import ( - Any, - Callable, - List, - Optional, - Union, -) +from typing import Any, Callable, List, Optional, Union import numpy as np import torch +from diffusers import DDPMScheduler, OnnxRuntimeModel, StableDiffusionUpscalePipeline +from diffusers.pipeline_utils import ImagePipelineOutput +from PIL import Image logger = getLogger(__name__) -num_channels_latents = 4 # self.vae.config.latent_channels -unet_in_channels = 7 # self.unet.config.in_channels +num_channels_latents = 4 # self.vae.config.latent_channels +unet_in_channels = 7 # self.unet.config.in_channels ### # This is based on a combination of the ONNX img2img pipeline and the PyTorch upscale pipeline: @@ -30,6 +18,7 @@ unet_in_channels = 7 # self.unet.config.in_channels # https://github.com/huggingface/diffusers/blob/v0.11.1/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py ### + def preprocess(image): if isinstance(image, torch.Tensor): return image @@ -63,8 +52,15 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): scheduler: Any, max_noise_level: int = 350, ): - super().__init__(vae, text_encoder, tokenizer, unet, - low_res_scheduler, scheduler, max_noise_level) + super().__init__( + vae, + text_encoder, + tokenizer, + unet, + low_res_scheduler, + scheduler, + max_noise_level, + ) def __call__( self, @@ -96,7 +92,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): # 3. Encode input prompt text_embeddings = self._encode_prompt( - prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, ) # 4. Preprocess image @@ -111,7 +111,9 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): text_embeddings_dtype = torch.float32 noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) - noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings_dtype) + noise = torch.randn( + image.shape, generator=generator, device=device, dtype=text_embeddings_dtype + ) image = self.low_res_scheduler.add_noise(image, noise, noise_level) batch_multiplier = 2 if do_classifier_free_guidance else 1 @@ -135,11 +137,11 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): num_channels_image = image.shape[1] if num_channels_latents + num_channels_image != unet_in_channels: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet` expects" - f" {unet_in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_image`: {num_channels_image} " - f" = {num_channels_latents+num_channels_image}. Please verify the config of" - " `pipeline.unet` or your `image` input." + "Incorrect configuration settings! The config of `pipeline.unet`" + f" expects {unet_in_channels} but received `num_channels_latents`:" + f" {num_channels_latents} + `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the" + " config of `pipeline.unet` or your `image` input." ) # 8. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline @@ -150,10 +152,16 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = ( + np.concatenate([latents] * 2) + if do_classifier_free_guidance + else latents + ) # concat latents, mask, masked_image_latents in the channel dimension - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) latent_model_input = np.concatenate([latent_model_input, image], axis=1) # timestep to tensor @@ -164,19 +172,25 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): sample=latent_model_input, timestep=timestep, encoder_hidden_states=text_embeddings, - class_labels=noise_level + class_labels=noise_level, )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) @@ -200,7 +214,14 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): image = image.transpose((0, 2, 3, 1)) return image - def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + ): batch_size = len(prompt) if isinstance(prompt, list) else 1 text_inputs = self.tokenizer( @@ -211,13 +232,20 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" + "The following part of your input was truncated because CLIP can only" + f" handle sequences up to 
{self.tokenizer.model_max_length} tokens:"
+                f" {removed_text}"
             )
 
         # no positional arguments to text_encoder
@@ -240,16 +268,17 @@ class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline):
             uncond_tokens = [""] * batch_size
         elif type(prompt) is not type(negative_prompt):
             raise TypeError(
-                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                f" {type(prompt)}."
+                "`negative_prompt` should be the same type as `prompt`, but got"
+                f" {type(negative_prompt)} != {type(prompt)}."
             )
         elif isinstance(negative_prompt, str):
             uncond_tokens = [negative_prompt]
         elif batch_size != len(negative_prompt):
             raise ValueError(
-                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                " the batch size of `prompt`."
+                f"`negative_prompt`: {negative_prompt} has batch size"
+                f" {len(negative_prompt)}, but `prompt`: {prompt} has batch size"
+                f" {batch_size}. Please make sure that the passed `negative_prompt`"
+                " matches the batch size of `prompt`."
             )
         else:
             uncond_tokens = negative_prompt
diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py
index efe27803..00ad3479 100644
--- a/api/onnx_web/diffusion/run.py
+++ b/api/onnx_web/diffusion/run.py
@@ -1,41 +1,17 @@
-from diffusers import (
-    OnnxStableDiffusionPipeline,
-    OnnxStableDiffusionImg2ImgPipeline,
-)
 from logging import getLogger
-from PIL import Image, ImageChops
 from typing import Any
 
-from ..chain import (
-    upscale_outpaint,
-)
-from ..device_pool import (
-    JobContext,
-)
-from ..params import (
-    ImageParams,
-    Border,
-    Size,
-    StageParams,
-)
-from ..output import (
-    save_image,
-    save_params,
-)
-from ..upscale import (
-    run_upscale_correction,
-    UpscaleParams,
-)
-from ..utils import (
-    run_gc,
-    ServerContext,
-)
-from .load import (
-    get_latents_from_seed,
-    load_pipeline,
-)
-
 import numpy as np
+from diffusers import OnnxStableDiffusionImg2ImgPipeline, OnnxStableDiffusionPipeline
+from PIL import Image, ImageChops
+
+from ..chain import upscale_outpaint
+from ..device_pool import JobContext
+from ..output import save_image, save_params
+from ..params import Border, ImageParams, Size, StageParams
+from ..upscale import UpscaleParams, run_upscale_correction
+from ..utils import ServerContext, run_gc
+from .load import get_latents_from_seed, load_pipeline
 
 logger = getLogger(__name__)
 
@@ -46,10 +22,11 @@ def run_txt2img_pipeline(
     params: ImageParams,
     size: Size,
     output: str,
-    upscale: UpscaleParams
+    upscale: UpscaleParams,
 ) -> None:
-    pipe = load_pipeline(OnnxStableDiffusionPipeline,
-                         params.model, params.scheduler, job.get_device())
+    pipe = load_pipeline(
+        OnnxStableDiffusionPipeline, params.model, params.scheduler, job.get_device()
+    )
 
     latents = get_latents_from_seed(params.seed, size)
     rng = np.random.RandomState(params.seed)
@@ -68,7 +45,8 @@ def run_txt2img_pipeline(
     )
     image = result.images[0]
     image = run_upscale_correction(
-        job, server, StageParams(), params, image, upscale=upscale)
+        job, server, StageParams(), params, image, upscale=upscale
+    )
 
     dest = save_image(server, output, image)
     save_params(server, output, params, size, upscale=upscale)
@@ -77,7 +55,7 @@ def run_txt2img_pipeline(
     del result
     run_gc()
 
-    logger.info('finished txt2img job: %s', dest)
+    logger.info("finished txt2img job: %s", dest)
 
 
 def run_img2img_pipeline(
@@ -89,8 +67,12 @@ def run_img2img_pipeline(
     source_image: Image.Image,
     strength: float,
 ) -> 
None: - pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline, - params.model, params.scheduler, job.get_device()) + pipe = load_pipeline( + OnnxStableDiffusionImg2ImgPipeline, + params.model, + params.scheduler, + job.get_device(), + ) rng = np.random.RandomState(params.seed) @@ -107,7 +89,8 @@ def run_img2img_pipeline( ) image = result.images[0] image = run_upscale_correction( - job, server, StageParams(), params, image, upscale=upscale) + job, server, StageParams(), params, image, upscale=upscale + ) dest = save_image(server, output, image) size = Size(*source_image.size) @@ -117,7 +100,7 @@ def run_img2img_pipeline( del result run_gc() - logger.info('finished img2img job: %s', dest) + logger.info("finished img2img job: %s", dest) def run_inpaint_pipeline( @@ -151,16 +134,14 @@ def run_inpaint_pipeline( mask_filter=mask_filter, noise_source=noise_source, ) - logger.info('applying mask filter and generating noise source') + logger.info("applying mask filter and generating noise source") if image.size == source_image.size: image = ImageChops.blend(source_image, image, strength) else: - logger.info( - 'output image size does not match source, skipping post-blend') + logger.info("output image size does not match source, skipping post-blend") - image = run_upscale_correction( - job, server, stage, params, image, upscale=upscale) + image = run_upscale_correction(job, server, stage, params, image, upscale=upscale) dest = save_image(server, output, image) save_params(server, output, params, size, upscale=upscale, border=border) @@ -168,7 +149,7 @@ def run_inpaint_pipeline( del image run_gc() - logger.info('finished inpaint job: %s', dest) + logger.info("finished inpaint job: %s", dest) def run_upscale_pipeline( @@ -185,7 +166,8 @@ def run_upscale_pipeline( stage = StageParams() image = run_upscale_correction( - job, server, stage, params, source_image, upscale=upscale) + job, server, stage, params, source_image, upscale=upscale + ) dest = save_image(server, output, image) save_params(server, output, params, size, upscale=upscale) @@ -193,4 +175,4 @@ def run_upscale_pipeline( del image run_gc() - logger.info('finished upscale job: %s', dest) + logger.info("finished upscale job: %s", dest) diff --git a/api/onnx_web/image.py b/api/onnx_web/image.py index 2c3678ba..b2c99695 100644 --- a/api/onnx_web/image.py +++ b/api/onnx_web/image.py @@ -1,31 +1,31 @@ +import numpy as np from numpy import random from PIL import Image, ImageChops, ImageFilter -import numpy as np - -from .params import ( - Border, - Point, -) +from .params import Border, Point def get_pixel_index(x: int, y: int, width: int) -> int: return (y * width) + x -def mask_filter_none(mask_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image: +def mask_filter_none( + mask_image: Image.Image, dims: Point, origin: Point, fill="white", **kw +) -> Image.Image: width, height = dims - noise = Image.new('RGB', (width, height), fill) + noise = Image.new("RGB", (width, height), fill) noise.paste(mask_image, origin) return noise -def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image: - ''' +def mask_filter_gaussian_multiply( + mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw +) -> Image.Image: + """ Gaussian blur with multiply, source image centered on white canvas. 
- ''' + """ noise = mask_filter_none(mask_image, dims, origin) for i in range(rounds): @@ -35,10 +35,12 @@ def mask_filter_gaussian_multiply(mask_image: Image.Image, dims: Point, origin: return noise -def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image: - ''' +def mask_filter_gaussian_screen( + mask_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw +) -> Image.Image: + """ Gaussian blur, source image centered on white canvas. - ''' + """ noise = mask_filter_none(mask_image, dims, origin) for i in range(rounds): @@ -48,33 +50,39 @@ def mask_filter_gaussian_screen(mask_image: Image.Image, dims: Point, origin: Po return noise -def noise_source_fill_edge(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image: - ''' +def noise_source_fill_edge( + source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw +) -> Image.Image: + """ Identity transform, source image centered on white canvas. - ''' + """ width, height = dims - noise = Image.new('RGB', (width, height), fill) + noise = Image.new("RGB", (width, height), fill) noise.paste(source_image, origin) return noise -def noise_source_fill_mask(source_image: Image.Image, dims: Point, origin: Point, fill='white', **kw) -> Image.Image: - ''' +def noise_source_fill_mask( + source_image: Image.Image, dims: Point, origin: Point, fill="white", **kw +) -> Image.Image: + """ Fill the whole canvas, no source or noise. - ''' + """ width, height = dims - noise = Image.new('RGB', (width, height), fill) + noise = Image.new("RGB", (width, height), fill) return noise -def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw) -> Image.Image: - ''' +def noise_source_gaussian( + source_image: Image.Image, dims: Point, origin: Point, rounds=3, **kw +) -> Image.Image: + """ Gaussian blur, source image centered on white canvas. 
- ''' + """ noise = noise_source_uniform(source_image, dims, origin) noise.paste(source_image, origin) @@ -84,7 +92,9 @@ def noise_source_gaussian(source_image: Image.Image, dims: Point, origin: Point, return noise -def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image: +def noise_source_uniform( + source_image: Image.Image, dims: Point, origin: Point, **kw +) -> Image.Image: width, height = dims size = width * height @@ -92,21 +102,19 @@ def noise_source_uniform(source_image: Image.Image, dims: Point, origin: Point, noise_g = random.uniform(0, 256, size=size) noise_b = random.uniform(0, 256, size=size) - noise = Image.new('RGB', (width, height)) + noise = Image.new("RGB", (width, height)) for x in range(width): for y in range(height): i = get_pixel_index(x, y, width) - noise.putpixel((x, y), ( - int(noise_r[i]), - int(noise_g[i]), - int(noise_b[i]) - )) + noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) return noise -def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image: +def noise_source_normal( + source_image: Image.Image, dims: Point, origin: Point, **kw +) -> Image.Image: width, height = dims size = width * height @@ -114,21 +122,19 @@ def noise_source_normal(source_image: Image.Image, dims: Point, origin: Point, * noise_g = random.normal(128, 32, size=size) noise_b = random.normal(128, 32, size=size) - noise = Image.new('RGB', (width, height)) + noise = Image.new("RGB", (width, height)) for x in range(width): for y in range(height): i = get_pixel_index(x, y, width) - noise.putpixel((x, y), ( - int(noise_r[i]), - int(noise_g[i]), - int(noise_b[i]) - )) + noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) return noise -def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point, **kw) -> Image.Image: +def noise_source_histogram( + source_image: Image.Image, dims: Point, origin: Point, **kw +) -> Image.Image: r, g, b = source_image.split() width, height = dims size = width * height @@ -137,35 +143,34 @@ def noise_source_histogram(source_image: Image.Image, dims: Point, origin: Point hist_g = g.histogram() hist_b = b.histogram() - noise_r = random.choice(256, p=np.divide( - np.copy(hist_r), np.sum(hist_r)), size=size) - noise_g = random.choice(256, p=np.divide( - np.copy(hist_g), np.sum(hist_g)), size=size) - noise_b = random.choice(256, p=np.divide( - np.copy(hist_b), np.sum(hist_b)), size=size) + noise_r = random.choice( + 256, p=np.divide(np.copy(hist_r), np.sum(hist_r)), size=size + ) + noise_g = random.choice( + 256, p=np.divide(np.copy(hist_g), np.sum(hist_g)), size=size + ) + noise_b = random.choice( + 256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size + ) - noise = Image.new('RGB', (width, height)) + noise = Image.new("RGB", (width, height)) for x in range(width): for y in range(height): i = get_pixel_index(x, y, width) - noise.putpixel((x, y), ( - noise_r[i], - noise_g[i], - noise_b[i] - )) + noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i])) return noise # very loosely based on https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/scripts/outpainting_mk_2.py#L175-L232 def expand_image( - source_image: Image.Image, - mask_image: Image.Image, - expand: Border, - fill='white', - noise_source=noise_source_histogram, - mask_filter=mask_filter_none, + source_image: Image.Image, + mask_image: Image.Image, + expand: Border, + fill="white", + noise_source=noise_source_histogram, + 
mask_filter=mask_filter_none, ): full_width = expand.left + source_image.width + expand.right full_height = expand.top + source_image.height + expand.bottom @@ -173,14 +178,13 @@ def expand_image( dims = (full_width, full_height) origin = (expand.left, expand.top) - full_source = Image.new('RGB', dims, fill) + full_source = Image.new("RGB", dims, fill) full_source.paste(source_image, origin) full_mask = mask_filter(mask_image, dims, origin, fill=fill) full_noise = noise_source(source_image, dims, origin, fill=fill) full_noise = ImageChops.multiply(full_noise, full_mask) - full_source = Image.composite( - full_noise, full_source, full_mask.convert('L')) + full_source = Image.composite(full_noise, full_source, full_mask.convert("L")) return (full_source, full_mask, full_noise, (full_width, full_height)) diff --git a/api/onnx_web/logging.py b/api/onnx_web/logging.py index 2e589846..b34277f7 100644 --- a/api/onnx_web/logging.py +++ b/api/onnx_web/logging.py @@ -1,14 +1,15 @@ from logging.config import dictConfig from os import environ, path + from yaml import safe_load -logging_path = environ.get('ONNX_WEB_LOGGING_PATH', './logging.yaml') +logging_path = environ.get("ONNX_WEB_LOGGING_PATH", "./logging.yaml") # setup logging config before anything else loads try: - if path.exists(logging_path): - with open(logging_path, 'r') as f: - config_logging = safe_load(f) - dictConfig(config_logging) + if path.exists(logging_path): + with open(logging_path, "r") as f: + config_logging = safe_load(f) + dictConfig(config_logging) except Exception as err: - print('error loading logging config: %s' % (err)) + print("error loading logging config: %s" % (err)) diff --git a/api/onnx_web/onnx/__init__.py b/api/onnx_web/onnx/__init__.py index 609d3473..9d30760b 100644 --- a/api/onnx_web/onnx/__init__.py +++ b/api/onnx_web/onnx/__init__.py @@ -1,4 +1 @@ -from .onnx_net import ( - OnnxImage, - OnnxNet, -) \ No newline at end of file +from .onnx_net import OnnxImage, OnnxNet diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 20b99098..9e6c1d7e 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -1,15 +1,14 @@ -from onnxruntime import InferenceSession from os import path from typing import Any, Optional import numpy as np import torch +from onnxruntime import InferenceSession -from ..utils import ( - ServerContext, -) +from ..utils import ServerContext -class OnnxImage(): + +class OnnxImage: def __init__(self, source) -> None: self.source = source self.data = self @@ -38,28 +37,27 @@ class OnnxImage(): return np.shape(self.source) -class OnnxNet(): - ''' +class OnnxNet: + """ Provides the RRDBNet interface using an ONNX session for DirectML acceleration. 
- ''' + """ def __init__( self, server: ServerContext, model: str, - provider: str = 'DmlExecutionProvider', + provider: str = "DmlExecutionProvider", provider_options: Optional[dict] = None, ) -> None: model_path = path.join(server.model_path, model) self.session = InferenceSession( - model_path, providers=[provider], provider_options=provider_options) + model_path, providers=[provider], provider_options=provider_options + ) def __call__(self, image: Any) -> Any: input_name = self.session.get_inputs()[0].name output_name = self.session.get_outputs()[0].name - output = self.session.run([output_name], { - input_name: image.cpu().numpy() - })[0] + output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0] return OnnxImage(output) def eval(self) -> None: diff --git a/api/onnx_web/output.py b/api/onnx_web/output.py index c71391e1..0b65641c 100644 --- a/api/onnx_web/output.py +++ b/api/onnx_web/output.py @@ -1,22 +1,14 @@ from hashlib import sha256 from json import dumps from logging import getLogger -from PIL import Image from struct import pack from time import time from typing import Any, Optional, Tuple -from .params import ( - Border, - ImageParams, - Param, - Size, - UpscaleParams, -) -from .utils import ( - base_join, - ServerContext, -) +from PIL import Image + +from .params import Border, ImageParams, Param, Size, UpscaleParams +from .utils import ServerContext, base_join logger = getLogger(__name__) @@ -25,13 +17,13 @@ def hash_value(sha, param: Param): if param is None: return elif isinstance(param, float): - sha.update(bytearray(pack('!f', param))) + sha.update(bytearray(pack("!f", param))) elif isinstance(param, int): - sha.update(bytearray(pack('!I', param))) + sha.update(bytearray(pack("!I", param))) elif isinstance(param, str): - sha.update(param.encode('utf-8')) + sha.update(param.encode("utf-8")) else: - logger.warn('cannot hash param: %s, %s', param, type(param)) + logger.warn("cannot hash param: %s, %s", param, type(param)) def json_params( @@ -42,22 +34,22 @@ def json_params( border: Optional[Border] = None, ) -> Any: json = { - 'output': output, - 'params': params.tojson(), + "output": output, + "params": params.tojson(), } if upscale is not None and border is not None: size = upscale.resize(size.add_border(border)) if upscale is not None: - json['upscale'] = upscale.tojson() + json["upscale"] = upscale.tojson() size = upscale.resize(size) if border is not None: - json['border'] = border.tojson() + json["border"] = border.tojson() size = size.add_border(border) - json['size'] = size.tojson() + json["size"] = size.tojson() return json @@ -67,7 +59,7 @@ def make_output_name( mode: str, params: ImageParams, size: Size, - extras: Optional[Tuple[Param]] = None + extras: Optional[Tuple[Param]] = None, ) -> str: now = int(time()) sha = sha256() @@ -87,13 +79,19 @@ def make_output_name( for param in extras: hash_value(sha, param) - return '%s_%s_%s_%s.%s' % (mode, params.seed, sha.hexdigest(), now, ctx.image_format) + return "%s_%s_%s_%s.%s" % ( + mode, + params.seed, + sha.hexdigest(), + now, + ctx.image_format, + ) def save_image(ctx: ServerContext, output: str, image: Image.Image) -> str: path = base_join(ctx.output_path, output) image.save(path, format=ctx.image_format) - logger.debug('saved output image to: %s', path) + logger.debug("saved output image to: %s", path) return path @@ -105,9 +103,9 @@ def save_params( upscale: Optional[UpscaleParams] = None, border: Optional[Border] = None, ) -> str: - path = base_join(ctx.output_path, '%s.json' % (output)) + 
path = base_join(ctx.output_path, "%s.json" % (output)) json = json_params(output, params, size, upscale=upscale, border=border) - with open(path, 'w') as f: + with open(path, "w") as f: f.write(dumps(json)) - logger.debug('saved image params to: %s', path) + logger.debug("saved image params to: %s", path) return path diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index b2592ee2..67793866 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -3,9 +3,9 @@ from typing import Any, Dict, Literal, Optional, Tuple, Union class SizeChart(IntEnum): - mini = 128 # small tile for very expensive models - half = 256 # half tile for outpainting - auto = 512 # auto tile size + mini = 128 # small tile for very expensive models + half = 256 # half tile for outpainting + auto = 512 # auto tile size hd1k = 2**10 hd2k = 2**11 hd4k = 2**12 @@ -26,14 +26,14 @@ class Border: self.bottom = bottom def __str__(self) -> str: - return '%s %s %s %s' % (self.left, self.top, self.right, self.bottom) + return "%s %s %s %s" % (self.left, self.top, self.right, self.bottom) def tojson(self): return { - 'left': self.left, - 'right': self.right, - 'top': self.top, - 'bottom': self.bottom, + "left": self.left, + "right": self.right, + "top": self.top, + "bottom": self.bottom, } @classmethod @@ -47,32 +47,37 @@ class Size: self.height = height def __str__(self) -> str: - return '%sx%s' % (self.width, self.height) + return "%sx%s" % (self.width, self.height) def add_border(self, border: Border): - return Size(border.left + self.width + border.right, border.top + self.height + border.right) + return Size( + border.left + self.width + border.right, + border.top + self.height + border.right, + ) def tojson(self) -> Dict[str, int]: return { - 'height': self.height, - 'width': self.width, + "height": self.height, + "width": self.width, } class DeviceParams: - def __init__(self, device: str, provider: str, options: Optional[dict] = None) -> None: + def __init__( + self, device: str, provider: str, options: Optional[dict] = None + ) -> None: self.device = device self.provider = provider self.options = options def __str__(self) -> str: - return '%s - %s (%s)' % (self.device, self.provider, self.options) + return "%s - %s (%s)" % (self.device, self.provider, self.options) def torch_device(self) -> str: - if self.device.startswith('cuda'): + if self.device.startswith("cuda"): return self.device else: - return 'cpu' + return "cpu" class ImageParams: @@ -84,7 +89,7 @@ class ImageParams: negative_prompt: Optional[str], cfg: float, steps: int, - seed: int + seed: int, ) -> None: self.model = model self.scheduler = scheduler @@ -96,20 +101,20 @@ class ImageParams: def tojson(self) -> Dict[str, Optional[Param]]: return { - 'model': self.model, - 'scheduler': self.scheduler.__name__, - 'seed': self.seed, - 'prompt': self.prompt, - 'cfg': self.cfg, - 'negativePrompt': self.negative_prompt, - 'steps': self.steps, + "model": self.model, + "scheduler": self.scheduler.__name__, + "seed": self.seed, + "prompt": self.prompt, + "cfg": self.cfg, + "negativePrompt": self.negative_prompt, + "steps": self.steps, } class StageParams: - ''' + """ Parameters for a chained pipeline stage - ''' + """ def __init__( self, @@ -123,7 +128,7 @@ class StageParams: self.outscale = outscale -class UpscaleParams(): +class UpscaleParams: def __init__( self, upscale_model: str, @@ -131,7 +136,7 @@ class UpscaleParams(): denoise: float = 0.5, faces=True, face_strength: float = 0.5, - format: Literal['onnx', 'pth'] = 'onnx', + format: 
Literal["onnx", "pth"] = "onnx", half=False, outscale: int = 1, scale: int = 4, @@ -170,8 +175,8 @@ class UpscaleParams(): def tojson(self): return { - 'model': self.upscale_model, - 'scale': self.scale, - 'outscale': self.outscale, + "model": self.upscale_model, + "scale": self.scale, + "outscale": self.outscale, # TODO: add more } diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 2e527e3b..c361c203 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -1,62 +1,61 @@ -from . import logging +import gc +from functools import cmp_to_key +from glob import glob +from io import BytesIO +from logging import getLogger +from os import makedirs, path +from typing import List, Tuple + +import numpy as np +import torch +import yaml from diffusers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, - EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, HeunDiscreteScheduler, + KarrasVeScheduler, KDPM2AncestralDiscreteScheduler, KDPM2DiscreteScheduler, - KarrasVeScheduler, LMSDiscreteScheduler, PNDMScheduler, ) from flask import Flask, jsonify, make_response, request, send_from_directory, url_for from flask_cors import CORS -from functools import cmp_to_key -from glob import glob -from io import BytesIO from jsonschema import validate -from logging import getLogger -from PIL import Image from onnxruntime import get_available_providers -from os import makedirs, path -from typing import List, Tuple - +from PIL import Image +from . import logging from .chain import ( + ChainPipeline, blend_img2img, blend_inpaint, correct_gfpgan, persist_disk, persist_s3, - reduce_thumbnail, reduce_crop, + reduce_thumbnail, source_noise, source_txt2img, upscale_outpaint, upscale_resrgan, upscale_stable_diffusion, - ChainPipeline, -) -from .device_pool import ( - DeviceParams, - DevicePoolExecutor, ) +from .device_pool import DevicePoolExecutor from .diffusion.run import ( run_img2img_pipeline, run_inpaint_pipeline, run_txt2img_pipeline, run_upscale_pipeline, ) -from .image import ( - # mask filters +from .image import ( # mask filters; noise sources mask_filter_gaussian_multiply, mask_filter_gaussian_screen, mask_filter_none, - # noise sources noise_source_fill_edge, noise_source_fill_mask, noise_source_gaussian, @@ -64,35 +63,20 @@ from .image import ( noise_source_normal, noise_source_uniform, ) -from .output import ( - json_params, - make_output_name, -) -from .params import ( - Border, - DeviceParams, - ImageParams, - Size, - StageParams, - UpscaleParams, -) +from .output import json_params, make_output_name +from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams from .utils import ( + ServerContext, base_join, - is_debug, get_and_clamp_float, get_and_clamp_int, get_from_list, get_from_map, get_not_empty, get_size, - ServerContext, + is_debug, ) -import gc -import numpy as np -import torch -import yaml - logger = getLogger(__name__) # config caching @@ -100,53 +84,53 @@ config_params = {} # pipeline params platform_providers = { - 'amd': 'DmlExecutionProvider', - 'cpu': 'CPUExecutionProvider', - 'cuda': 'CUDAExecutionProvider', - 'directml': 'DmlExecutionProvider', - 'nvidia': 'CUDAExecutionProvider', - 'rocm': 'ROCMExecutionProvider', + "amd": "DmlExecutionProvider", + "cpu": "CPUExecutionProvider", + "cuda": "CUDAExecutionProvider", + "directml": "DmlExecutionProvider", + "nvidia": "CUDAExecutionProvider", + "rocm": "ROCMExecutionProvider", } pipeline_schedulers = { - 'ddim': 
DDIMScheduler, - 'ddpm': DDPMScheduler, - 'dpm-multi': DPMSolverMultistepScheduler, - 'dpm-single': DPMSolverSinglestepScheduler, - 'euler': EulerDiscreteScheduler, - 'euler-a': EulerAncestralDiscreteScheduler, - 'heun': HeunDiscreteScheduler, - 'k-dpm-2-a': KDPM2AncestralDiscreteScheduler, - 'k-dpm-2': KDPM2DiscreteScheduler, - 'karras-ve': KarrasVeScheduler, - 'lms-discrete': LMSDiscreteScheduler, - 'pndm': PNDMScheduler, + "ddim": DDIMScheduler, + "ddpm": DDPMScheduler, + "dpm-multi": DPMSolverMultistepScheduler, + "dpm-single": DPMSolverSinglestepScheduler, + "euler": EulerDiscreteScheduler, + "euler-a": EulerAncestralDiscreteScheduler, + "heun": HeunDiscreteScheduler, + "k-dpm-2-a": KDPM2AncestralDiscreteScheduler, + "k-dpm-2": KDPM2DiscreteScheduler, + "karras-ve": KarrasVeScheduler, + "lms-discrete": LMSDiscreteScheduler, + "pndm": PNDMScheduler, } noise_sources = { - 'fill-edge': noise_source_fill_edge, - 'fill-mask': noise_source_fill_mask, - 'gaussian': noise_source_gaussian, - 'histogram': noise_source_histogram, - 'normal': noise_source_normal, - 'uniform': noise_source_uniform, + "fill-edge": noise_source_fill_edge, + "fill-mask": noise_source_fill_mask, + "gaussian": noise_source_gaussian, + "histogram": noise_source_histogram, + "normal": noise_source_normal, + "uniform": noise_source_uniform, } mask_filters = { - 'none': mask_filter_none, - 'gaussian-multiply': mask_filter_gaussian_multiply, - 'gaussian-screen': mask_filter_gaussian_screen, + "none": mask_filter_none, + "gaussian-multiply": mask_filter_gaussian_multiply, + "gaussian-screen": mask_filter_gaussian_screen, } chain_stages = { - 'blend-img2img': blend_img2img, - 'blend-inpaint': blend_inpaint, - 'correct-gfpgan': correct_gfpgan, - 'persist-disk': persist_disk, - 'persist-s3': persist_s3, - 'reduce-crop': reduce_crop, - 'reduce-thumbnail': reduce_thumbnail, - 'source-noise': source_noise, - 'source-txt2img': source_txt2img, - 'upscale-outpaint': upscale_outpaint, - 'upscale-resrgan': upscale_resrgan, - 'upscale-stable-diffusion': upscale_stable_diffusion, + "blend-img2img": blend_img2img, + "blend-inpaint": blend_inpaint, + "correct-gfpgan": correct_gfpgan, + "persist-disk": persist_disk, + "persist-s3": persist_s3, + "reduce-crop": reduce_crop, + "reduce-thumbnail": reduce_thumbnail, + "source-noise": source_noise, + "source-txt2img": source_txt2img, + "upscale-outpaint": upscale_outpaint, + "upscale-resrgan": upscale_resrgan, + "upscale-stable-diffusion": upscale_stable_diffusion, } # Available ORT providers @@ -158,7 +142,7 @@ correction_models = [] upscaling_models = [] -def get_config_value(key: str, subkey: str = 'default'): +def get_config_value(key: str, subkey: str = "default"): return config_params.get(key).get(subkey) @@ -174,7 +158,7 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: user = request.remote_addr # platform stuff - device_name = request.args.get('platform', available_platforms[0].device) + device_name = request.args.get("platform", available_platforms[0].device) device = None for platform in available_platforms: @@ -182,78 +166,101 @@ def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: device = available_platforms[0] if device is None: - raise Exception('unknown device') + raise Exception("unknown device") # pipeline stuff - model = get_not_empty(request.args, 'model', get_config_value('model')) + model = get_not_empty(request.args, "model", get_config_value("model")) model_path = get_model_path(model) - scheduler = get_from_map(request.args, 
'scheduler', - pipeline_schedulers, get_config_value('scheduler')) + scheduler = get_from_map( + request.args, "scheduler", pipeline_schedulers, get_config_value("scheduler") + ) # image params - prompt = get_not_empty(request.args, - 'prompt', get_config_value('prompt')) - negative_prompt = request.args.get('negativePrompt', None) + prompt = get_not_empty(request.args, "prompt", get_config_value("prompt")) + negative_prompt = request.args.get("negativePrompt", None) - if negative_prompt is not None and negative_prompt.strip() == '': + if negative_prompt is not None and negative_prompt.strip() == "": negative_prompt = None cfg = get_and_clamp_float( - request.args, 'cfg', - get_config_value('cfg'), - get_config_value('cfg', 'max'), - get_config_value('cfg', 'min')) + request.args, + "cfg", + get_config_value("cfg"), + get_config_value("cfg", "max"), + get_config_value("cfg", "min"), + ) steps = get_and_clamp_int( - request.args, 'steps', - get_config_value('steps'), - get_config_value('steps', 'max'), - get_config_value('steps', 'min')) + request.args, + "steps", + get_config_value("steps"), + get_config_value("steps", "max"), + get_config_value("steps", "min"), + ) height = get_and_clamp_int( - request.args, 'height', - get_config_value('height'), - get_config_value('height', 'max'), - get_config_value('height', 'min')) + request.args, + "height", + get_config_value("height"), + get_config_value("height", "max"), + get_config_value("height", "min"), + ) width = get_and_clamp_int( - request.args, 'width', - get_config_value('width'), - get_config_value('width', 'max'), - get_config_value('width', 'min')) + request.args, + "width", + get_config_value("width"), + get_config_value("width", "max"), + get_config_value("width", "min"), + ) - seed = int(request.args.get('seed', -1)) + seed = int(request.args.get("seed", -1)) if seed == -1: seed = np.random.randint(np.iinfo(np.int32).max) - logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", - user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt) + logger.info( + "request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", + user, + steps, + scheduler.__name__, + model_path, + device.provider, + width, + height, + cfg, + seed, + prompt, + ) - params = ImageParams(model_path, scheduler, prompt, - negative_prompt, cfg, steps, seed) + params = ImageParams( + model_path, scheduler, prompt, negative_prompt, cfg, steps, seed + ) size = Size(width, height) return (device, params, size) def border_from_request() -> Border: - left = get_and_clamp_int(request.args, 'left', 0, - get_config_value('width', 'max'), 0) - right = get_and_clamp_int(request.args, 'right', - 0, get_config_value('width', 'max'), 0) - top = get_and_clamp_int(request.args, 'top', 0, - get_config_value('height', 'max'), 0) + left = get_and_clamp_int( + request.args, "left", 0, get_config_value("width", "max"), 0 + ) + right = get_and_clamp_int( + request.args, "right", 0, get_config_value("width", "max"), 0 + ) + top = get_and_clamp_int( + request.args, "top", 0, get_config_value("height", "max"), 0 + ) bottom = get_and_clamp_int( - request.args, 'bottom', 0, get_config_value('height', 'max'), 0) + request.args, "bottom", 0, get_config_value("height", "max"), 0 + ) return Border(left, right, top, bottom) def upscale_from_request() -> UpscaleParams: - denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0) - scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1) - outscale = 
get_and_clamp_int(request.args, 'outscale', 1, 4, 1) - upscaling = get_from_list(request.args, 'upscaling', upscaling_models) - correction = get_from_list(request.args, 'correction', correction_models) - faces = get_not_empty(request.args, 'faces', 'false') == 'true' - face_strength = get_and_clamp_float( - request.args, 'faceStrength', 0.5, 1.0, 0.0) + denoise = get_and_clamp_float(request.args, "denoise", 0.5, 1.0, 0.0) + scale = get_and_clamp_int(request.args, "scale", 1, 4, 1) + outscale = get_and_clamp_int(request.args, "outscale", 1, 4, 1) + upscaling = get_from_list(request.args, "upscaling", upscaling_models) + correction = get_from_list(request.args, "correction", correction_models) + faces = get_not_empty(request.args, "faces", "false") == "true" + face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) return UpscaleParams( upscaling, @@ -261,7 +268,7 @@ def upscale_from_request() -> UpscaleParams: denoise=denoise, faces=faces, face_strength=face_strength, - format='onnx', + format="onnx", outscale=outscale, scale=scale, ) @@ -269,7 +276,7 @@ def upscale_from_request() -> UpscaleParams: def check_paths(context: ServerContext): if not path.exists(context.model_path): - raise RuntimeError('model path must exist') + raise RuntimeError("model path must exist") if not path.exists(context.output_path): makedirs(context.output_path) @@ -286,35 +293,41 @@ def load_models(context: ServerContext): global correction_models global upscaling_models - diffusion_models = [get_model_name(f) for f in glob( - path.join(context.model_path, 'diffusion-*'))] - diffusion_models.extend([ - get_model_name(f) for f in glob(path.join(context.model_path, 'stable-diffusion-*'))]) + diffusion_models = [ + get_model_name(f) for f in glob(path.join(context.model_path, "diffusion-*")) + ] + diffusion_models.extend( + [ + get_model_name(f) + for f in glob(path.join(context.model_path, "stable-diffusion-*")) + ] + ) diffusion_models = list(set(diffusion_models)) diffusion_models.sort() correction_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, 'correction-*'))] + get_model_name(f) for f in glob(path.join(context.model_path, "correction-*")) + ] correction_models = list(set(correction_models)) correction_models.sort() upscaling_models = [ - get_model_name(f) for f in glob(path.join(context.model_path, 'upscaling-*'))] + get_model_name(f) for f in glob(path.join(context.model_path, "upscaling-*")) + ] upscaling_models = list(set(upscaling_models)) upscaling_models.sort() def load_params(context: ServerContext): global config_params - params_file = path.join(context.params_path, 'params.json') - with open(params_file, 'r') as f: + params_file = path.join(context.params_path, "params.json") + with open(params_file, "r") as f: config_params = yaml.safe_load(f) - if 'platform' in config_params and context.default_platform is not None: - logger.info('overriding default platform to %s', - context.default_platform) - config_platform = config_params.get('platform') - config_platform['default'] = context.default_platform + if "platform" in config_params and context.default_platform is not None: + logger.info("overriding default platform to %s", context.default_platform) + config_platform = config_params.get("platform") + config_platform["default"] = context.default_platform def load_platforms(): @@ -323,30 +336,42 @@ def load_platforms(): providers = get_available_providers() for potential in platform_providers: - if platform_providers[potential] in providers and 
potential not in context.block_platforms:
-            if potential == 'cuda':
+        if (
+            platform_providers[potential] in providers
+            and potential not in context.block_platforms
+        ):
+            if potential == "cuda":
                 for i in range(torch.cuda.device_count()):
-                    available_platforms.append(DeviceParams(potential, platform_providers[potential], {
-                        'device_id': i,
-                    }))
+                    available_platforms.append(
+                        DeviceParams(
+                            potential,
+                            platform_providers[potential],
+                            {
+                                "device_id": i,
+                            },
+                        )
+                    )
             else:
-                available_platforms.append(DeviceParams(
-                    potential, platform_providers[potential]))
+                available_platforms.append(
+                    DeviceParams(potential, platform_providers[potential])
+                )
 
     # make sure CPU is last on the list
     def cpu_last(a: DeviceParams, b: DeviceParams):
-        if a.device == 'cpu' and b.device == 'cpu':
+        if a.device == "cpu" and b.device == "cpu":
             return 0
 
-        if a.device == 'cpu':
+        if a.device == "cpu":
             return 1
 
         return -1
 
     available_platforms = sorted(available_platforms, key=cmp_to_key(cpu_last))
 
-    logger.info('available acceleration platforms: %s',
-                ', '.join([str(p) for p in available_platforms]))
+    logger.info(
+        "available acceleration platforms: %s",
+        ", ".join([str(p) for p in available_platforms]),
+    )
 
 
 context = ServerContext.from_environ()
@@ -365,16 +390,22 @@ if is_debug():
 
 
 def ready_reply(ready: bool, progress: int = 0):
-    return jsonify({
-        'progress': progress,
-        'ready': ready,
-    })
+    return jsonify(
+        {
+            "progress": progress,
+            "ready": ready,
+        }
+    )
 
 
 def error_reply(err: str):
-    response = make_response(jsonify({
-        'error': err,
-    }))
+    response = make_response(
+        jsonify(
+            {
+                "error": err,
+            }
+        )
+    )
     response.status_code = 400
 
     return response
@@ -383,151 +414,154 @@ def get_model_path(model: str):
     return base_join(context.model_path, model)
 
 
-def serve_bundle_file(filename='index.html'):
-    return send_from_directory(path.join('..', context.bundle_path), filename)
+def serve_bundle_file(filename="index.html"):
+    return send_from_directory(path.join("..", context.bundle_path), filename)
 
 
 # routes
-@app.route('/')
+@app.route("/")
 def index():
     return serve_bundle_file()
 
 
-@app.route('/<path:filename>')
+@app.route("/<path:filename>")
 def index_path(filename):
     return serve_bundle_file(filename)
 
 
-@app.route('/api')
+@app.route("/api")
 def introspect():
     return {
-        'name': 'onnx-web',
-        'routes': [{
-            'path': url_from_rule(rule),
-            'methods': list(rule.methods).sort()
-        } for rule in app.url_map.iter_rules()]
+        "name": "onnx-web",
+        "routes": [
+            {"path": url_from_rule(rule), "methods": list(rule.methods).sort()}
+            for rule in app.url_map.iter_rules()
+        ],
     }
 
 
-@app.route('/api/settings/masks')
+@app.route("/api/settings/masks")
 def list_mask_filters():
     return jsonify(list(mask_filters.keys()))
 
 
-@app.route('/api/settings/models')
+@app.route("/api/settings/models")
 def list_models():
-    return jsonify({
-        'diffusion': diffusion_models,
-        'correction': correction_models,
-        'upscaling': upscaling_models,
-    })
+    return jsonify(
+        {
+            "diffusion": diffusion_models,
+            "correction": correction_models,
+            "upscaling": upscaling_models,
+        }
+    )
 
 
-@app.route('/api/settings/noises')
+@app.route("/api/settings/noises")
 def list_noise_sources():
     return jsonify(list(noise_sources.keys()))
 
 
-@app.route('/api/settings/params')
+@app.route("/api/settings/params")
 def list_params():
     return jsonify(config_params)
 
 
-@app.route('/api/settings/platforms')
+@app.route("/api/settings/platforms")
 def list_platforms():
     return jsonify([p.device for p in available_platforms])
 
 
-@app.route('/api/settings/schedulers')
+@app.route("/api/settings/schedulers")
def 
list_schedulers(): return jsonify(list(pipeline_schedulers.keys())) -@app.route('/api/img2img', methods=['POST']) +@app.route("/api/img2img", methods=["POST"]) def img2img(): - if 'source' not in request.files: - return error_reply('source image is required') + if "source" not in request.files: + return error_reply("source image is required") - source_file = request.files.get('source') - source_image = Image.open(BytesIO(source_file.read())).convert('RGB') + source_file = request.files.get("source") + source_image = Image.open(BytesIO(source_file.read())).convert("RGB") device, params, size = pipeline_from_request() upscale = upscale_from_request() strength = get_and_clamp_float( request.args, - 'strength', - get_config_value('strength'), - get_config_value('strength', 'max'), - get_config_value('strength', 'min')) + "strength", + get_config_value("strength"), + get_config_value("strength", "max"), + get_config_value("strength", "min"), + ) - output = make_output_name( - context, - 'img2img', - params, - size, - extras=(strength,)) + output = make_output_name(context, "img2img", params, size, extras=(strength,)) logger.info("img2img job queued for: %s", output) source_image.thumbnail((size.width, size.height)) - executor.submit(output, run_img2img_pipeline, - context, params, output, upscale, source_image, strength) + executor.submit( + output, + run_img2img_pipeline, + context, + params, + output, + upscale, + source_image, + strength, + ) return jsonify(json_params(output, params, size, upscale=upscale)) -@app.route('/api/txt2img', methods=['POST']) +@app.route("/api/txt2img", methods=["POST"]) def txt2img(): device, params, size = pipeline_from_request() upscale = upscale_from_request() - output = make_output_name( - context, - 'txt2img', - params, - size) + output = make_output_name(context, "txt2img", params, size) logger.info("txt2img job queued for: %s", output) executor.submit( - output, run_txt2img_pipeline, context, params, size, output, upscale) + output, run_txt2img_pipeline, context, params, size, output, upscale + ) return jsonify(json_params(output, params, size, upscale=upscale)) -@app.route('/api/inpaint', methods=['POST']) +@app.route("/api/inpaint", methods=["POST"]) def inpaint(): - if 'source' not in request.files: - return error_reply('source image is required') + if "source" not in request.files: + return error_reply("source image is required") - if 'mask' not in request.files: - return error_reply('mask image is required') + if "mask" not in request.files: + return error_reply("mask image is required") - source_file = request.files.get('source') - source_image = Image.open(BytesIO(source_file.read())).convert('RGB') + source_file = request.files.get("source") + source_image = Image.open(BytesIO(source_file.read())).convert("RGB") - mask_file = request.files.get('mask') - mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB') + mask_file = request.files.get("mask") + mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB") device, params, size = pipeline_from_request() expand = border_from_request() upscale = upscale_from_request() - fill_color = get_not_empty(request.args, 'fillColor', 'white') - mask_filter = get_from_map(request.args, 'filter', mask_filters, 'none') - noise_source = get_from_map( - request.args, 'noise', noise_sources, 'histogram') + fill_color = get_not_empty(request.args, "fillColor", "white") + mask_filter = get_from_map(request.args, "filter", mask_filters, "none") + noise_source = get_from_map(request.args, "noise", 
noise_sources, "histogram") strength = get_and_clamp_float( request.args, - 'strength', - get_config_value('strength'), - get_config_value('strength', 'max'), - get_config_value('strength', 'min')) + "strength", + get_config_value("strength"), + get_config_value("strength", "max"), + get_config_value("strength", "min"), + ) output = make_output_name( context, - 'inpaint', + "inpaint", params, size, extras=( @@ -539,7 +573,7 @@ def inpaint(): noise_source.__name__, strength, fill_color, - ) + ), ) logger.info("inpaint job queued for: %s", output) @@ -559,123 +593,131 @@ def inpaint(): noise_source, mask_filter, strength, - fill_color) + fill_color, + ) return jsonify(json_params(output, params, size, upscale=upscale, border=expand)) -@app.route('/api/upscale', methods=['POST']) +@app.route("/api/upscale", methods=["POST"]) def upscale(): - if 'source' not in request.files: - return error_reply('source image is required') + if "source" not in request.files: + return error_reply("source image is required") - source_file = request.files.get('source') - source_image = Image.open(BytesIO(source_file.read())).convert('RGB') + source_file = request.files.get("source") + source_image = Image.open(BytesIO(source_file.read())).convert("RGB") device, params, size = pipeline_from_request() upscale = upscale_from_request() - output = make_output_name( - context, - 'upscale', - params, - size) + output = make_output_name(context, "upscale", params, size) logger.info("upscale job queued for: %s", output) source_image.thumbnail((size.width, size.height)) - executor.submit(output, run_upscale_pipeline, - context, params, size, output, upscale, source_image) + executor.submit( + output, + run_upscale_pipeline, + context, + params, + size, + output, + upscale, + source_image, + ) return jsonify(json_params(output, params, size, upscale=upscale)) -@app.route('/api/chain', methods=['POST']) +@app.route("/api/chain", methods=["POST"]) def chain(): - logger.debug('chain pipeline request: %s, %s', - request.form.keys(), request.files.keys()) - body = request.form.get('chain') or request.files.get('chain') + logger.debug( + "chain pipeline request: %s, %s", request.form.keys(), request.files.keys() + ) + body = request.form.get("chain") or request.files.get("chain") if body is None: - return error_reply('chain pipeline must have a body') + return error_reply("chain pipeline must have a body") data = yaml.safe_load(body) - with open('./schema.yaml', 'r') as f: + with open("./schema.yaml", "r") as f: schema = yaml.safe_load(f.read()) - logger.info('validating chain request: %s against %s', data, schema) + logger.info("validating chain request: %s against %s", data, schema) validate(data, schema) # get defaults from the regular parameters device, params, size = pipeline_from_request() - output = make_output_name( - context, - 'chain', - params, - size) + output = make_output_name(context, "chain", params, size) pipeline = ChainPipeline() - for stage_data in data.get('stages', []): - callback = chain_stages[stage_data.get('type')] - kwargs = stage_data.get('params', {}) - logger.info('request stage: %s, %s', callback.__name__, kwargs) + for stage_data in data.get("stages", []): + callback = chain_stages[stage_data.get("type")] + kwargs = stage_data.get("params", {}) + logger.info("request stage: %s, %s", callback.__name__, kwargs) stage = StageParams( - stage_data.get('name', callback.__name__), - tile_size=get_size(kwargs.get('tile_size')), - outscale=get_and_clamp_int(kwargs, 'outscale', 1, 4), + 
stage_data.get("name", callback.__name__),
+            tile_size=get_size(kwargs.get("tile_size")),
+            outscale=get_and_clamp_int(kwargs, "outscale", 1, 4),
         )
 
-        if 'border' in kwargs:
-            border = Border.even(int(kwargs.get('border')))
-            kwargs['border'] = border
+        if "border" in kwargs:
+            border = Border.even(int(kwargs.get("border")))
+            kwargs["border"] = border
 
-        if 'upscale' in kwargs:
-            upscale = UpscaleParams(kwargs.get('upscale'))
-            kwargs['upscale'] = upscale
+        if "upscale" in kwargs:
+            upscale = UpscaleParams(kwargs.get("upscale"))
+            kwargs["upscale"] = upscale
 
-        stage_source_name = 'source:%s' % (stage.name)
-        stage_mask_name = 'mask:%s' % (stage.name)
+        stage_source_name = "source:%s" % (stage.name)
+        stage_mask_name = "mask:%s" % (stage.name)
 
         if stage_source_name in request.files:
-            logger.debug('loading source image %s for pipeline stage %s',
-                         stage_source_name, stage.name)
+            logger.debug(
+                "loading source image %s for pipeline stage %s",
+                stage_source_name,
+                stage.name,
+            )
             source_file = request.files.get(stage_source_name)
-            source_image = Image.open(
-                BytesIO(source_file.read())).convert('RGB')
+            source_image = Image.open(BytesIO(source_file.read())).convert("RGB")
             source_image = source_image.thumbnail((512, 512))
-            kwargs['source_image'] = source_image
+            kwargs["source_image"] = source_image
 
         if stage_mask_name in request.files:
-            logger.debug('loading mask image %s for pipeline stage %s',
-                         stage_mask_name, stage.name)
+            logger.debug(
+                "loading mask image %s for pipeline stage %s",
+                stage_mask_name,
+                stage.name,
+            )
             mask_file = request.files.get(stage_mask_name)
-            mask_image = Image.open(BytesIO(mask_file.read())).convert('RGB')
+            mask_image = Image.open(BytesIO(mask_file.read())).convert("RGB")
             mask_image = mask_image.thumbnail((512, 512))
-            kwargs['mask_image'] = mask_image
+            kwargs["mask_image"] = mask_image
 
         pipeline.append((callback, stage, kwargs))
 
-    logger.info('running chain pipeline with %s stages', len(pipeline.stages))
+    logger.info("running chain pipeline with %s stages", len(pipeline.stages))
 
     # build and run chain pipeline
-    empty_source = Image.new('RGB', (size.width, size.height))
-    executor.submit(output, pipeline, context,
-                    params, empty_source, output=output, size=size)
+    empty_source = Image.new("RGB", (size.width, size.height))
+    executor.submit(
+        output, pipeline, context, params, empty_source, output=output, size=size
+    )
 
     return jsonify(json_params(output, params, size))
 
 
-@app.route('/api/cancel', methods=['PUT'])
+@app.route("/api/cancel", methods=["PUT"])
 def cancel():
-    output_file = request.args.get('output', None)
+    output_file = request.args.get("output", None)
 
     cancel = executor.cancel(output_file)
 
     return ready_reply(cancel)
 
 
-@app.route('/api/ready')
+@app.route("/api/ready")
 def ready():
-    output_file = request.args.get('output', None)
+    output_file = request.args.get("output", None)
 
     done, progress = executor.done(output_file)
@@ -687,11 +729,13 @@ def ready():
     return ready_reply(done, progress=progress)
 
 
-@app.route('/api/status')
+@app.route("/api/status")
 def status():
     return jsonify(executor.status())
 
 
-@app.route('/output/<path:filename>')
+@app.route("/output/<path:filename>")
 def output(filename: str):
-    return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)
+    return send_from_directory(
+        path.join("..", context.output_path), filename, as_attachment=False
+    )
diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py
index 3d765c43..1862338b 100644
--- a/api/onnx_web/upscale.py
+++ b/api/onnx_web/upscale.py
@@ -1,24 +1,16 @@ 
from logging import getLogger + from PIL import Image from .chain import ( - correct_gfpgan, - upscale_stable_diffusion, - upscale_resrgan, ChainPipeline, + correct_gfpgan, + upscale_resrgan, + upscale_stable_diffusion, ) -from .device_pool import ( - JobContext, -) -from .params import ( - ImageParams, - SizeChart, - StageParams, - UpscaleParams, -) -from .utils import ( - ServerContext, -) +from .device_pool import JobContext +from .params import ImageParams, SizeChart, StageParams, UpscaleParams +from .utils import ServerContext logger = getLogger(__name__) @@ -32,27 +24,25 @@ def run_upscale_correction( *, upscale: UpscaleParams, ) -> Image.Image: - ''' + """ This is a convenience method for a chain pipeline that will run upscaling and correction, based on the `upscale` params. - ''' - logger.info('running upscaling and correction pipeline') + """ + logger.info("running upscaling and correction pipeline") chain = ChainPipeline() if upscale.scale > 1: - if 'esrgan' in upscale.upscale_model: - stage = StageParams(tile_size=stage.tile_size, - outscale=upscale.outscale) + if "esrgan" in upscale.upscale_model: + stage = StageParams(tile_size=stage.tile_size, outscale=upscale.outscale) chain.append((upscale_resrgan, stage, None)) - elif 'stable-diffusion' in upscale.upscale_model: + elif "stable-diffusion" in upscale.upscale_model: mini_tile = min(SizeChart.mini, stage.tile_size) stage = StageParams(tile_size=mini_tile, outscale=upscale.outscale) chain.append((upscale_stable_diffusion, stage, None)) if upscale.faces: - stage = StageParams(tile_size=stage.tile_size, - outscale=1) + stage = StageParams(tile_size=stage.tile_size, outscale=1) chain.append((correct_gfpgan, stage, None)) return chain(job, server, params, image, prompt=params.prompt, upscale=upscale) diff --git a/api/onnx_web/utils.py b/api/onnx_web/utils.py index 8387d189..473d424c 100644 --- a/api/onnx_web/utils.py +++ b/api/onnx_web/utils.py @@ -1,13 +1,11 @@ +import gc from logging import getLogger from os import environ, path from typing import Any, Dict, List, Optional, Union -import gc import torch -from .params import ( - SizeChart, -) +from .params import SizeChart logger = getLogger(__name__) @@ -15,15 +13,15 @@ logger = getLogger(__name__) class ServerContext: def __init__( self, - bundle_path: str = '.', - model_path: str = '.', - output_path: str = '.', - params_path: str = '.', - cors_origin: str = '*', + bundle_path: str = ".", + model_path: str = ".", + output_path: str = ".", + params_path: str = ".", + cors_origin: str = "*", num_workers: int = 1, block_platforms: List[str] = [], default_platform: str = None, - image_format: str = 'png', + image_format: str = "png", ) -> None: self.bundle_path = bundle_path self.model_path = model_path @@ -38,40 +36,39 @@ class ServerContext: @classmethod def from_environ(cls): return ServerContext( - bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH', - path.join('..', 'gui', 'out')), - model_path=environ.get('ONNX_WEB_MODEL_PATH', - path.join('..', 'models')), - output_path=environ.get( - 'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')), - params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'), - # others - cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','), - num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)), - block_platforms=environ.get( - 'ONNX_WEB_BLOCK_PLATFORMS', '').split(','), - default_platform=environ.get( - 'ONNX_WEB_DEFAULT_PLATFORM', None), - image_format=environ.get( - 'ONNX_WEB_IMAGE_FORMAT', 'png' + bundle_path=environ.get( + 
"ONNX_WEB_BUNDLE_PATH", path.join("..", "gui", "out") ), + model_path=environ.get("ONNX_WEB_MODEL_PATH", path.join("..", "models")), + output_path=environ.get("ONNX_WEB_OUTPUT_PATH", path.join("..", "outputs")), + params_path=environ.get("ONNX_WEB_PARAMS_PATH", "."), + # others + cors_origin=environ.get("ONNX_WEB_CORS_ORIGIN", "*").split(","), + num_workers=int(environ.get("ONNX_WEB_NUM_WORKERS", 1)), + block_platforms=environ.get("ONNX_WEB_BLOCK_PLATFORMS", "").split(","), + default_platform=environ.get("ONNX_WEB_DEFAULT_PLATFORM", None), + image_format=environ.get("ONNX_WEB_IMAGE_FORMAT", "png"), ) def base_join(base: str, tail: str) -> str: - tail_path = path.relpath(path.normpath(path.join('/', tail)), '/') + tail_path = path.relpath(path.normpath(path.join("/", tail)), "/") return path.join(base, tail_path) def is_debug() -> bool: - return environ.get('DEBUG') is not None + return environ.get("DEBUG") is not None -def get_and_clamp_float(args: Any, key: str, default_value: float, max_value: float, min_value=0.0) -> float: +def get_and_clamp_float( + args: Any, key: str, default_value: float, max_value: float, min_value=0.0 +) -> float: return min(max(float(args.get(key, default_value)), min_value), max_value) -def get_and_clamp_int(args: Any, key: str, default_value: int, max_value: int, min_value=1) -> int: +def get_and_clamp_int( + args: Any, key: str, default_value: int, max_value: int, min_value=1 +) -> int: return min(max(int(args.get(key, default_value)), min_value), max_value) @@ -80,7 +77,7 @@ def get_from_list(args: Any, key: str, values: List[Any]) -> Optional[Any]: if selected in values: return selected - logger.warn('invalid selection: %s', selected) + logger.warn("invalid selection: %s", selected) if len(values) > 0: return values[0] @@ -118,10 +115,10 @@ def get_size(val: Union[int, str, None]) -> SizeChart: return int(val) - raise Exception('invalid size') + raise Exception("invalid size") def run_gc(): - logger.debug('running garbage collection') + logger.debug("running garbage collection") gc.collect() torch.cuda.empty_cache() diff --git a/api/pyproject.toml b/api/pyproject.toml new file mode 100644 index 00000000..5d7bf33d --- /dev/null +++ b/api/pyproject.toml @@ -0,0 +1,2 @@ +[tool.isort] +profile = "black"