From 2913cd03826b9257abdb7a7660f08a6acc170958 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 1 Jul 2023 07:10:53 -0500 Subject: [PATCH] feat(api): make chain stages into classes with max tile size and step count estimate --- api/onnx_web/__init__.py | 2 +- api/onnx_web/chain/__init__.py | 80 +++--- api/onnx_web/chain/base.py | 9 +- api/onnx_web/chain/blend_highres.py | 0 api/onnx_web/chain/blend_img2img.py | 146 ++++++----- api/onnx_web/chain/blend_inpaint.py | 195 +++++++------- api/onnx_web/chain/blend_linear.py | 28 +- api/onnx_web/chain/blend_mask.py | 42 +-- api/onnx_web/chain/correct_codeformer.py | 42 +-- api/onnx_web/chain/correct_gfpgan.py | 118 +++++---- api/onnx_web/chain/persist_disk.py | 32 +-- api/onnx_web/chain/persist_s3.py | 54 ++-- api/onnx_web/chain/reduce_crop.py | 36 +-- api/onnx_web/chain/reduce_thumbnail.py | 36 +-- api/onnx_web/chain/source_noise.py | 42 +-- api/onnx_web/chain/source_s3.py | 52 ++-- api/onnx_web/chain/source_txt2img.py | 132 +++++----- api/onnx_web/chain/source_url.py | 44 ++-- api/onnx_web/chain/stage.py | 31 +++ api/onnx_web/{diffusers => chain}/upscale.py | 26 +- api/onnx_web/chain/upscale_bsrgan.py | 190 ++++++++------ api/onnx_web/chain/upscale_highres.py | 107 ++++---- api/onnx_web/chain/upscale_outpaint.py | 248 +++++++++--------- api/onnx_web/chain/upscale_resrgan.py | 127 ++++----- .../chain/upscale_stable_diffusion.py | 92 +++---- api/onnx_web/chain/upscale_swinir.py | 185 ++++++------- api/onnx_web/diffusers/run.py | 32 +-- api/onnx_web/image/utils.py | 4 +- api/onnx_web/server/api.py | 8 +- 29 files changed, 1121 insertions(+), 1019 deletions(-) delete mode 100644 api/onnx_web/chain/blend_highres.py create mode 100644 api/onnx_web/chain/stage.py rename api/onnx_web/{diffusers => chain}/upscale.py (80%) diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index a4f1cdaf..27bd382e 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -17,7 +17,7 @@ from .diffusers.run import ( run_upscale_pipeline, ) from .diffusers.stub_scheduler import StubScheduler -from .diffusers.upscale import stage_upscale_correction +from .chain.upscale import stage_upscale_correction from .image.utils import ( expand_image, ) diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index e7858bd2..4f3b5e9f 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,44 +1,44 @@ from .base import ChainPipeline, PipelineStage, StageCallback, StageParams -from .blend_img2img import blend_img2img -from .blend_inpaint import blend_inpaint -from .blend_linear import blend_linear -from .blend_mask import blend_mask -from .correct_codeformer import correct_codeformer -from .correct_gfpgan import correct_gfpgan -from .persist_disk import persist_disk -from .persist_s3 import persist_s3 -from .reduce_crop import reduce_crop -from .reduce_thumbnail import reduce_thumbnail -from .source_noise import source_noise -from .source_s3 import source_s3 -from .source_txt2img import source_txt2img -from .source_url import source_url -from .upscale_bsrgan import upscale_bsrgan -from .upscale_highres import upscale_highres -from .upscale_outpaint import upscale_outpaint -from .upscale_resrgan import upscale_resrgan -from .upscale_stable_diffusion import upscale_stable_diffusion -from .upscale_swinir import upscale_swinir +from .blend_img2img import BlendImg2ImgStage +from .blend_inpaint import BlendInpaintStage +from .blend_linear import BlendLinearStage +from .blend_mask import BlendMaskStage +from 
.correct_codeformer import CorrectCodeformerStage +from .correct_gfpgan import CorrectGFPGANStage +from .persist_disk import PersistDiskStage +from .persist_s3 import PersistS3Stage +from .reduce_crop import ReduceCropStage +from .reduce_thumbnail import ReduceThumbnailStage +from .source_noise import SourceNoiseStage +from .source_s3 import SourceS3Stage +from .source_txt2img import SourceTxt2ImgStage +from .source_url import SourceURLStage +from .upscale_bsrgan import UpscaleBSRGANStage +from .upscale_highres import UpscaleHighresStage +from .upscale_outpaint import UpscaleOutpaintStage +from .upscale_resrgan import UpscaleRealESRGANStage +from .upscale_stable_diffusion import UpscaleStableDiffusionStage +from .upscale_swinir import UpscaleSwinIRStage CHAIN_STAGES = { - "blend-img2img": blend_img2img, - "blend-inpaint": blend_inpaint, - "blend-linear": blend_linear, - "blend-mask": blend_mask, - "correct-codeformer": correct_codeformer, - "correct-gfpgan": correct_gfpgan, - "persist-disk": persist_disk, - "persist-s3": persist_s3, - "reduce-crop": reduce_crop, - "reduce-thumbnail": reduce_thumbnail, - "source-noise": source_noise, - "source-s3": source_s3, - "source-txt2img": source_txt2img, - "source-url": source_url, - "upscale-bsrgan": upscale_bsrgan, - "upscale-highres": upscale_highres, - "upscale-outpaint": upscale_outpaint, - "upscale-resrgan": upscale_resrgan, - "upscale-stable-diffusion": upscale_stable_diffusion, - "upscale-swinir": upscale_swinir, + "blend-img2img": BlendImg2ImgStage, + "blend-inpaint": BlendInpaintStage, + "blend-linear": BlendLinearStage, + "blend-mask": BlendMaskStage, + "correct-codeformer": CorrectCodeformerStage, + "correct-gfpgan": CorrectGFPGANStage, + "persist-disk": PersistDiskStage, + "persist-s3": PersistS3Stage, + "reduce-crop": ReduceCropStage, + "reduce-thumbnail": ReduceThumbnailStage, + "source-noise": SourceNoiseStage, + "source-s3": SourceS3Stage, + "source-txt2img": SourceTxt2ImgStage, + "source-url": SourceURLStage, + "upscale-bsrgan": UpscaleBSRGANStage, + "upscale-highres": UpscaleHighresStage, + "upscale-outpaint": UpscaleOutpaintStage, + "upscale-resrgan": UpscaleRealESRGANStage, + "upscale-stable-diffusion": UpscaleStableDiffusionStage, + "upscale-swinir": UpscaleSwinIRStage, } diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 2b2ecaaa..2a3c0c63 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -10,6 +10,7 @@ from ..params import ImageParams, StageParams from ..server import ServerContext from ..utils import is_debug from ..worker import ProgressCallback, WorkerContext +from .stage import BaseStage from .utils import process_tile_order logger = getLogger(__name__) @@ -35,7 +36,7 @@ class StageCallback(Protocol): pass -PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]] +PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] class ChainProgress: @@ -131,7 +132,7 @@ class ChainPipeline: logger.info("running pipeline without source image") for stage_pipe, stage_params, stage_kwargs in self.stages: - name = stage_params.name or stage_pipe.__name__ + name = stage_params.name or stage_pipe.__class__.__name__ kwargs = stage_kwargs or {} kwargs = {**pipeline_kwargs, **kwargs} @@ -158,7 +159,7 @@ class ChainPipeline: ) def stage_tile(tile: Image.Image, _dims) -> Image.Image: - tile = stage_pipe( + tile = stage_pipe.run( job, server, stage_params, @@ -182,7 +183,7 @@ class ChainPipeline: ) else: logger.debug("image within tile size, running stage") - image = stage_pipe( + 
image = stage_pipe.run( job, server, stage_params, diff --git a/api/onnx_web/chain/blend_highres.py b/api/onnx_web/chain/blend_highres.py deleted file mode 100644 index e69de29b..00000000 diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 7a6873ad..293f57fd 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -14,77 +14,81 @@ from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) -def blend_img2img( - job: WorkerContext, - server: ServerContext, - _stage: StageParams, - params: ImageParams, - source: Image.Image, - *, - strength: float, - callback: Optional[ProgressCallback] = None, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - params = params.with_args(**kwargs) - source = stage_source or source - logger.info( - "blending image using img2img, %s steps: %s", params.steps, params.prompt - ) - - prompt_pairs, loras, inversions = parse_prompt(params) - - pipe_type = params.get_valid_pipeline("img2img") - pipe = load_pipeline( - server, - params, - pipe_type, - job.get_device(), - inversions=inversions, - loras=loras, - ) - - pipe_params = {} - if pipe_type == "controlnet": - pipe_params["controlnet_conditioning_scale"] = strength - elif pipe_type == "img2img": - pipe_params["strength"] = strength - elif pipe_type == "panorama": - pipe_params["strength"] = strength - elif pipe_type == "pix2pix": - pipe_params["image_guidance_scale"] = strength - - if params.lpw(): - logger.debug("using LPW pipeline for img2img") - rng = torch.manual_seed(params.seed) - result = pipe.img2img( - params.prompt, - generator=rng, - guidance_scale=params.cfg, - image=source, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - callback=callback, - **pipe_params, - ) - else: - # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg()) - pipe.unet.set_prompts(prompt_embeds) - - rng = np.random.RandomState(params.seed) - result = pipe( - params.prompt, - generator=rng, - guidance_scale=params.cfg, - image=source, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - callback=callback, - **pipe_params, +class BlendImg2ImgStage: + def run( + self, + job: WorkerContext, + server: ServerContext, + _stage: StageParams, + params: ImageParams, + source: Image.Image, + *, + strength: float, + callback: Optional[ProgressCallback] = None, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + params = params.with_args(**kwargs) + source = stage_source or source + logger.info( + "blending image using img2img, %s steps: %s", params.steps, params.prompt ) - output = result.images[0] + prompt_pairs, loras, inversions = parse_prompt(params) - logger.info("final output image size: %sx%s", output.width, output.height) - return output + pipe_type = params.get_valid_pipeline("img2img") + pipe = load_pipeline( + server, + params, + pipe_type, + job.get_device(), + inversions=inversions, + loras=loras, + ) + + pipe_params = {} + if pipe_type == "controlnet": + pipe_params["controlnet_conditioning_scale"] = strength + elif pipe_type == "img2img": + pipe_params["strength"] = strength + elif pipe_type == "panorama": + pipe_params["strength"] = strength + elif pipe_type == "pix2pix": + pipe_params["image_guidance_scale"] = strength + + if params.lpw(): + logger.debug("using LPW pipeline for img2img") + rng = torch.manual_seed(params.seed) + result = 
pipe.img2img( + params.prompt, + generator=rng, + guidance_scale=params.cfg, + image=source, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + callback=callback, + **pipe_params, + ) + else: + # encode and record alternative prompts outside of LPW + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) + + rng = np.random.RandomState(params.seed) + result = pipe( + params.prompt, + generator=rng, + guidance_scale=params.cfg, + image=source, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + callback=callback, + **pipe_params, + ) + + output = result.images[0] + + logger.info("final output image size: %sx%s", output.width, output.height) + return output diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 3de0e512..c613bcbe 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -18,105 +18,112 @@ from .utils import process_tile_order logger = getLogger(__name__) -def blend_inpaint( - job: WorkerContext, - server: ServerContext, - stage: StageParams, - params: ImageParams, - source: Image.Image, - *, - expand: Border, - stage_source: Optional[Image.Image] = None, - stage_mask: Optional[Image.Image] = None, - fill_color: str = "white", - mask_filter: Callable = mask_filter_none, - noise_source: Callable = noise_source_histogram, - callback: Optional[ProgressCallback] = None, - **kwargs, -) -> Image.Image: - params = params.with_args(**kwargs) - expand = expand.with_args(**kwargs) - source = source or stage_source - logger.info( - "blending image using inpaint, %s steps: %s", params.steps, params.prompt - ) +class BlendInpaintStage: + def run( + self, + job: WorkerContext, + server: ServerContext, + stage: StageParams, + params: ImageParams, + source: Image.Image, + *, + expand: Border, + stage_source: Optional[Image.Image] = None, + stage_mask: Optional[Image.Image] = None, + fill_color: str = "white", + mask_filter: Callable = mask_filter_none, + noise_source: Callable = noise_source_histogram, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> Image.Image: + params = params.with_args(**kwargs) + expand = expand.with_args(**kwargs) + source = source or stage_source + logger.info( + "blending image using inpaint, %s steps: %s", params.steps, params.prompt + ) - if stage_mask is None: - # if no mask was provided, keep the full source image - stage_mask = Image.new("RGB", source.size, "black") + if stage_mask is None: + # if no mask was provided, keep the full source image + stage_mask = Image.new("RGB", source.size, "black") - source, stage_mask, noise, _full_dims = expand_image( - source, - stage_mask, - expand, - fill=fill_color, - noise_source=noise_source, - mask_filter=mask_filter, - ) - - if is_debug(): - save_image(server, "last-source.png", source) - save_image(server, "last-mask.png", stage_mask) - save_image(server, "last-noise.png", noise) - - pipe_type = "lpw" if params.lpw() else "inpaint" - pipe = load_pipeline( - server, - params, - pipe_type, - job.get_device(), - # TODO: add LoRAs and TIs - ) - - def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): - left, top, tile = dims - size = Size(*tile_source.size) - tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) + source, stage_mask, noise, _full_dims = expand_image( + source, + stage_mask, + expand, + fill=fill_color, + noise_source=noise_source, + mask_filter=mask_filter, + ) if 
is_debug(): - save_image(server, "tile-source.png", tile_source) - save_image(server, "tile-mask.png", tile_mask) + save_image(server, "last-source.png", source) + save_image(server, "last-mask.png", stage_mask) + save_image(server, "last-noise.png", noise) - latents = get_latents_from_seed(params.seed, size) - if params.lpw(): - logger.debug("using LPW pipeline for inpaint") - rng = torch.manual_seed(params.seed) - result = pipe.inpaint( - params.prompt, - generator=rng, - guidance_scale=params.cfg, - height=size.height, - image=tile_source, - latents=latents, - mask_image=tile_mask, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - width=size.width, - eta=params.eta, - callback=callback, - ) - else: - rng = np.random.RandomState(params.seed) - result = pipe( - params.prompt, - generator=rng, - guidance_scale=params.cfg, - height=size.height, - image=tile_source, - latents=latents, - mask_image=stage_mask, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - width=size.width, - eta=params.eta, - callback=callback, - ) + pipe_type = "lpw" if params.lpw() else "inpaint" + pipe = load_pipeline( + server, + params, + pipe_type, + job.get_device(), + # TODO: add LoRAs and TIs + ) - return result.images[0] + def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): + left, top, tile = dims + size = Size(*tile_source.size) + tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) - output = process_tile_order( - stage.tile_order, source, SizeChart.auto, 1, [outpaint], overlap=params.overlap - ) + if is_debug(): + save_image(server, "tile-source.png", tile_source) + save_image(server, "tile-mask.png", tile_mask) - logger.info("final output image size: %s", output.size) - return output + latents = get_latents_from_seed(params.seed, size) + if params.lpw(): + logger.debug("using LPW pipeline for inpaint") + rng = torch.manual_seed(params.seed) + result = pipe.inpaint( + params.prompt, + generator=rng, + guidance_scale=params.cfg, + height=size.height, + image=tile_source, + latents=latents, + mask_image=tile_mask, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + width=size.width, + eta=params.eta, + callback=callback, + ) + else: + rng = np.random.RandomState(params.seed) + result = pipe( + params.prompt, + generator=rng, + guidance_scale=params.cfg, + height=size.height, + image=tile_source, + latents=latents, + mask_image=stage_mask, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + width=size.width, + eta=params.eta, + callback=callback, + ) + + return result.images[0] + + output = process_tile_order( + stage.tile_order, + source, + SizeChart.auto, + 1, + [outpaint], + overlap=params.overlap, + ) + + logger.info("final output image size: %s", output.size) + return output diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index 69f79429..4873a99c 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -10,17 +10,19 @@ from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) -def blend_linear( - _job: WorkerContext, - _server: ServerContext, - _stage: StageParams, - _params: ImageParams, - *, - alpha: float, - sources: Optional[List[Image.Image]] = None, - _callback: Optional[ProgressCallback] = None, - **kwargs, -) -> Image.Image: - logger.info("blending image using linear interpolation") +class BlendLinearStage: + def run( + self, + _job: WorkerContext, + _server: 
ServerContext, + _stage: StageParams, + _params: ImageParams, + *, + alpha: float, + sources: Optional[List[Image.Image]] = None, + _callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> Image.Image: + logger.info("blending image using linear interpolation") - return Image.blend(sources[1], sources[0], alpha) + return Image.blend(sources[1], sources[0], alpha) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 99aba6b9..d806fb17 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -12,26 +12,28 @@ from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) -def blend_mask( - _job: WorkerContext, - server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - stage_source: Optional[Image.Image] = None, - stage_mask: Optional[Image.Image] = None, - _callback: Optional[ProgressCallback] = None, - **kwargs, -) -> Image.Image: - logger.info("blending image using mask") +class BlendMaskStage: + def run( + self, + _job: WorkerContext, + server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + stage_source: Optional[Image.Image] = None, + stage_mask: Optional[Image.Image] = None, + _callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> Image.Image: + logger.info("blending image using mask") - mult_mask = Image.new("RGBA", stage_mask.size, color="black") - mult_mask.alpha_composite(stage_mask) - mult_mask = mult_mask.convert("L") + mult_mask = Image.new("RGBA", stage_mask.size, color="black") + mult_mask.alpha_composite(stage_mask) + mult_mask = mult_mask.convert("L") - if is_debug(): - save_image(server, "last-mask.png", stage_mask) - save_image(server, "last-mult-mask.png", mult_mask) + if is_debug(): + save_image(server, "last-mask.png", stage_mask) + save_image(server, "last-mult-mask.png", mult_mask) - return Image.composite(stage_source, source, mult_mask) + return Image.composite(stage_source, source, mult_mask) diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 60605b43..c950fdfc 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -9,28 +9,28 @@ from ..worker import WorkerContext logger = getLogger(__name__) -device = "cpu" +class CorrectCodeformerStage: + def run( + self, + job: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + stage_source: Optional[Image.Image] = None, + upscale: UpscaleParams, + **kwargs, + ) -> Image.Image: + # must be within the load function for patch to take effect + # TODO: rewrite and remove + from codeformer import CodeFormer -def correct_codeformer( - job: WorkerContext, - _server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - stage_source: Optional[Image.Image] = None, - upscale: UpscaleParams, - **kwargs, -) -> Image.Image: - # must be within the load function for patch to take effect - # TODO: rewrite and remove - from codeformer import CodeFormer + source = stage_source or source - source = stage_source or source + upscale = upscale.with_args(**kwargs) - upscale = upscale.with_args(**kwargs) - - device = job.get_device() - pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) - return pipe(source) + device = job.get_device() + pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) + return pipe(source) diff --git 
a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 4ef224a5..c42dbd7b 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -13,72 +13,74 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def load_gfpgan( - server: ServerContext, - _stage: StageParams, - upscale: UpscaleParams, - device: DeviceParams, -): - # must be within the load function for patch to take effect - # TODO: rewrite and remove - from gfpgan import GFPGANer +class CorrectGFPGANStage: + def load( + self, + server: ServerContext, + _stage: StageParams, + upscale: UpscaleParams, + device: DeviceParams, + ): + # must be within the load function for patch to take effect + # TODO: rewrite and remove + from gfpgan import GFPGANer - face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model)) - cache_key = (face_path,) - cache_pipe = server.cache.get("gfpgan", cache_key) + face_path = path.join(server.cache_path, "%s.pth" % (upscale.correction_model)) + cache_key = (face_path,) + cache_pipe = server.cache.get("gfpgan", cache_key) - if cache_pipe is not None: - logger.info("reusing existing GFPGAN pipeline") - return cache_pipe + if cache_pipe is not None: + logger.info("reusing existing GFPGAN pipeline") + return cache_pipe - logger.debug("loading GFPGAN model from %s", face_path) + logger.debug("loading GFPGAN model from %s", face_path) - # TODO: find a way to pass the ONNX model to underlying architectures - gfpgan = GFPGANer( - arch="clean", - bg_upsampler=None, - channel_multiplier=2, - device=device.torch_str(), - model_path=face_path, - upscale=upscale.face_outscale, - ) + # TODO: find a way to pass the ONNX model to underlying architectures + gfpgan = GFPGANer( + arch="clean", + bg_upsampler=None, + channel_multiplier=2, + device=device.torch_str(), + model_path=face_path, + upscale=upscale.face_outscale, + ) - server.cache.set("gfpgan", cache_key, gfpgan) - run_gc([device]) + server.cache.set("gfpgan", cache_key, gfpgan) + run_gc([device]) - return gfpgan + return gfpgan + def run( + self, + job: WorkerContext, + server: ServerContext, + stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + upscale: UpscaleParams, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + upscale = upscale.with_args(**kwargs) + source = stage_source or source -def correct_gfpgan( - job: WorkerContext, - server: ServerContext, - stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - upscale: UpscaleParams, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - upscale = upscale.with_args(**kwargs) - source = stage_source or source + if upscale.correction_model is None: + logger.warn("no face model given, skipping") + return source - if upscale.correction_model is None: - logger.warn("no face model given, skipping") - return source + logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) + device = job.get_device() + gfpgan = self.load(server, stage, upscale, device) - logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) - device = job.get_device() - gfpgan = load_gfpgan(server, stage, upscale, device) + output = np.array(source) + _, _, output = gfpgan.enhance( + output, + has_aligned=False, + only_center_face=False, + paste_back=True, + weight=upscale.face_strength, + ) + output = Image.fromarray(output, "RGB") - output = np.array(source) - _, _, output = gfpgan.enhance( - output, - 
has_aligned=False, - only_center_face=False, - paste_back=True, - weight=upscale.face_strength, - ) - output = Image.fromarray(output, "RGB") - - return output + return output diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index d5cecdf7..ea7abadb 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -10,19 +10,21 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def persist_disk( - _job: WorkerContext, - server: ServerContext, - _stage: StageParams, - params: ImageParams, - source: Image.Image, - *, - output: str, - stage_source: Image.Image, - **kwargs, -) -> Image.Image: - source = stage_source or source +class PersistDiskStage: + def run( + self, + _job: WorkerContext, + server: ServerContext, + _stage: StageParams, + params: ImageParams, + source: Image.Image, + *, + output: str, + stage_source: Image.Image, + **kwargs, + ) -> Image.Image: + source = stage_source or source - dest = save_image(server, output, source, params=params) - logger.info("saved image to %s", dest) - return source + dest = save_image(server, output, source, params=params) + logger.info("saved image to %s", dest) + return source diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index 3b755f53..c0916788 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -12,33 +12,35 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def persist_s3( - _job: WorkerContext, - server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - output: str, - bucket: str, - endpoint_url: Optional[str] = None, - profile_name: Optional[str] = None, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - source = stage_source or source +class PersistS3Stage: + def run( + self, + _job: WorkerContext, + server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + output: str, + bucket: str, + endpoint_url: Optional[str] = None, + profile_name: Optional[str] = None, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + source = stage_source or source - session = Session(profile_name=profile_name) - s3 = session.client("s3", endpoint_url=endpoint_url) + session = Session(profile_name=profile_name) + s3 = session.client("s3", endpoint_url=endpoint_url) - data = BytesIO() - source.save(data, format=server.image_format) - data.seek(0) + data = BytesIO() + source.save(data, format=server.image_format) + data.seek(0) - try: - s3.upload_fileobj(data, bucket, output) - logger.info("saved image to s3://%s/%s", bucket, output) - except Exception: - logger.exception("error saving image to S3") + try: + s3.upload_fileobj(data, bucket, output) + logger.info("saved image to s3://%s/%s", bucket, output) + except Exception: + logger.exception("error saving image to S3") - return source + return source diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 3f2d82db..8e88433a 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -10,20 +10,24 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def reduce_crop( - _job: WorkerContext, - _server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - origin: Size, - size: Size, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - source = stage_source or source +class 
ReduceCropStage: + def run( + self, + _job: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + origin: Size, + size: Size, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + source = stage_source or source - image = source.crop((origin.width, origin.height, size.width, size.height)) - logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) - return image + image = source.crop((origin.width, origin.height, size.width, size.height)) + logger.info( + "created thumbnail with dimensions: %sx%s", image.width, image.height + ) + return image diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index c51cda7a..6bc7b9d6 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -9,21 +9,25 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def reduce_thumbnail( - _job: WorkerContext, - _server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - size: Size, - stage_source: Image.Image, - **kwargs, -) -> Image.Image: - source = stage_source or source - image = source.copy() +class ReduceThumbnailStage: + def run( + self, + _job: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + size: Size, + stage_source: Image.Image, + **kwargs, + ) -> Image.Image: + source = stage_source or source + image = source.copy() - image = image.thumbnail((size.width, size.height)) + image = image.thumbnail((size.width, size.height)) - logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) - return image + logger.info( + "created thumbnail with dimensions: %sx%s", image.width, image.height + ) + return image diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 0092292c..71b80f1a 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -10,25 +10,29 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def source_noise( - _job: WorkerContext, - _server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - size: Size, - noise_source: Callable, - stage_source: Image.Image, - **kwargs, -) -> Image.Image: - source = stage_source or source - logger.info("generating image from noise source") +class SourceNoiseStage: + def run( + self, + _job: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + size: Size, + noise_source: Callable, + stage_source: Image.Image, + **kwargs, + ) -> Image.Image: + source = stage_source or source + logger.info("generating image from noise source") - if source is not None: - logger.warn("a source image was passed to a noise stage, but will be discarded") + if source is not None: + logger.warn( + "a source image was passed to a noise stage, but will be discarded" + ) - output = noise_source(source, (size.width, size.height), (0, 0)) + output = noise_source(source, (size.width, size.height), (0, 0)) - logger.info("final output image size: %sx%s", output.width, output.height) - return output + logger.info("final output image size: %sx%s", output.width, output.height) + return output diff --git a/api/onnx_web/chain/source_s3.py b/api/onnx_web/chain/source_s3.py index 7c4b4158..8aee6232 100644 --- a/api/onnx_web/chain/source_s3.py +++ b/api/onnx_web/chain/source_s3.py 
@@ -12,31 +12,33 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def source_s3( - _job: WorkerContext, - server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - source_key: str, - bucket: str, - endpoint_url: Optional[str] = None, - profile_name: Optional[str] = None, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - source = stage_source or source +class SourceS3Stage: + def run( + self, + _job: WorkerContext, + server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + source_key: str, + bucket: str, + endpoint_url: Optional[str] = None, + profile_name: Optional[str] = None, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + source = stage_source or source - session = Session(profile_name=profile_name) - s3 = session.client("s3", endpoint_url=endpoint_url) + session = Session(profile_name=profile_name) + s3 = session.client("s3", endpoint_url=endpoint_url) - try: - logger.info("loading image from s3://%s/%s", bucket, source_key) - data = BytesIO() - s3.download_fileobj(bucket, source_key, data) + try: + logger.info("loading image from s3://%s/%s", bucket, source_key) + data = BytesIO() + s3.download_fileobj(bucket, source_key, data) - data.seek(0) - return Image.open(data) - except Exception: - logger.exception("error loading image from S3") + data.seek(0) + return Image.open(data) + except Exception: + logger.exception("error loading image from S3") diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 5e681f71..5721d27a 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -14,74 +14,78 @@ from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) -def source_txt2img( - job: WorkerContext, - server: ServerContext, - _stage: StageParams, - params: ImageParams, - _source: Image.Image, - *, - size: Size, - callback: Optional[ProgressCallback] = None, - **kwargs, -) -> Image.Image: - params = params.with_args(**kwargs) - size = size.with_args(**kwargs) - logger.info( - "generating image using txt2img, %s steps: %s", params.steps, params.prompt - ) - - if "stage_source" in kwargs: - logger.warn( - "a source image was passed to a txt2img stage, and will be discarded" +class SourceTxt2ImgStage: + def run( + self, + job: WorkerContext, + server: ServerContext, + _stage: StageParams, + params: ImageParams, + _source: Image.Image, + *, + size: Size, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> Image.Image: + params = params.with_args(**kwargs) + size = size.with_args(**kwargs) + logger.info( + "generating image using txt2img, %s steps: %s", params.steps, params.prompt ) - prompt_pairs, loras, inversions = parse_prompt(params) + if "stage_source" in kwargs: + logger.warn( + "a source image was passed to a txt2img stage, and will be discarded" + ) - latents = get_latents_from_seed(params.seed, size) - pipe_type = params.get_valid_pipeline("txt2img") - pipe = load_pipeline( - server, - params, - pipe_type, - job.get_device(), - inversions=inversions, - loras=loras, - ) + prompt_pairs, loras, inversions = parse_prompt(params) - if params.lpw(): - logger.debug("using LPW pipeline for txt2img") - rng = torch.manual_seed(params.seed) - result = pipe.text2img( - params.prompt, - height=size.height, - width=size.width, - generator=rng, - guidance_scale=params.cfg, - latents=latents, - 
negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - callback=callback, - ) - else: - # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt(pipe, prompt_pairs, params.batch, params.do_cfg()) - pipe.unet.set_prompts(prompt_embeds) - - rng = np.random.RandomState(params.seed) - result = pipe( - params.prompt, - height=size.height, - width=size.width, - generator=rng, - guidance_scale=params.cfg, - latents=latents, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - callback=callback, + latents = get_latents_from_seed(params.seed, size) + pipe_type = params.get_valid_pipeline("txt2img") + pipe = load_pipeline( + server, + params, + pipe_type, + job.get_device(), + inversions=inversions, + loras=loras, ) - output = result.images[0] + if params.lpw(): + logger.debug("using LPW pipeline for txt2img") + rng = torch.manual_seed(params.seed) + result = pipe.text2img( + params.prompt, + height=size.height, + width=size.width, + generator=rng, + guidance_scale=params.cfg, + latents=latents, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + callback=callback, + ) + else: + # encode and record alternative prompts outside of LPW + prompt_embeds = encode_prompt( + pipe, prompt_pairs, params.batch, params.do_cfg() + ) + pipe.unet.set_prompts(prompt_embeds) - logger.info("final output image size: %sx%s", output.width, output.height) - return output + rng = np.random.RandomState(params.seed) + result = pipe( + params.prompt, + height=size.height, + width=size.width, + generator=rng, + guidance_scale=params.cfg, + latents=latents, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + callback=callback, + ) + + output = result.images[0] + + logger.info("final output image size: %sx%s", output.width, output.height) + return output diff --git a/api/onnx_web/chain/source_url.py b/api/onnx_web/chain/source_url.py index 3496be95..233d877f 100644 --- a/api/onnx_web/chain/source_url.py +++ b/api/onnx_web/chain/source_url.py @@ -11,27 +11,29 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def source_url( - _job: WorkerContext, - _server: ServerContext, - _stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - source_url: str, - stage_source: Image.Image, - **kwargs, -) -> Image.Image: - source = stage_source or source - logger.info("loading image from URL source") +class SourceURLStage: + def run( + self, + _job: WorkerContext, + _server: ServerContext, + _stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + source_url: str, + stage_source: Image.Image, + **kwargs, + ) -> Image.Image: + source = stage_source or source + logger.info("loading image from URL source") - if source is not None: - logger.warn( - "a source image was passed to a source stage, and will be discarded" - ) + if source is not None: + logger.warn( + "a source image was passed to a source stage, and will be discarded" + ) - response = requests.get(source_url) - output = Image.open(BytesIO(response.content)) + response = requests.get(source_url) + output = Image.open(BytesIO(response.content)) - logger.info("final output image size: %sx%s", output.width, output.height) - return output + logger.info("final output image size: %sx%s", output.width, output.height) + return output diff --git a/api/onnx_web/chain/stage.py b/api/onnx_web/chain/stage.py new file mode 100644 index 00000000..7e43cd59 --- /dev/null +++ b/api/onnx_web/chain/stage.py @@ -0,0 +1,31 
@@ +from typing import Optional + +from PIL import Image + +from onnx_web.params import ImageParams, Size, SizeChart, StageParams +from onnx_web.server.context import ServerContext +from onnx_web.worker.context import WorkerContext + + +class BaseStage: + max_tile = SizeChart.auto + + def run( + self, + job: WorkerContext, + server: ServerContext, + stage: StageParams, + _params: ImageParams, + source: Image.Image, + *args, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + raise NotImplementedError() + + def steps( + self, + _params: ImageParams, + size: Size, + ) -> int: + raise NotImplementedError() diff --git a/api/onnx_web/diffusers/upscale.py b/api/onnx_web/chain/upscale.py similarity index 80% rename from api/onnx_web/diffusers/upscale.py rename to api/onnx_web/chain/upscale.py index 42973951..b53f98b3 100644 --- a/api/onnx_web/diffusers/upscale.py +++ b/api/onnx_web/chain/upscale.py @@ -1,14 +1,14 @@ from logging import getLogger from typing import List, Optional, Tuple -from ..chain import ChainPipeline, PipelineStage -from ..chain.correct_codeformer import correct_codeformer -from ..chain.correct_gfpgan import correct_gfpgan -from ..chain.upscale_bsrgan import upscale_bsrgan -from ..chain.upscale_resrgan import upscale_resrgan -from ..chain.upscale_stable_diffusion import upscale_stable_diffusion -from ..chain.upscale_swinir import upscale_swinir from ..params import ImageParams, SizeChart, StageParams, UpscaleParams +from . import ChainPipeline, PipelineStage +from .correct_codeformer import CorrectCodeformerStage +from .correct_gfpgan import CorrectGFPGANStage +from .upscale_bsrgan import UpscaleBSRGANStage +from .upscale_resrgan import UpscaleRealESRGANStage +from .upscale_stable_diffusion import UpscaleStableDiffusionStage +from .upscale_swinir import UpscaleSwinIRStage logger = getLogger(__name__) @@ -72,23 +72,23 @@ def stage_upscale_correction( tile_size=stage.tile_size, outscale=upscale.outscale, ) - upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts) + upscale_stage = (UpscaleBSRGANStage(), bsrgan_params, upscale_opts) elif "esrgan" in upscale.upscale_model: esrgan_params = StageParams( tile_size=stage.tile_size, outscale=upscale.outscale, ) - upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts) + upscale_stage = (UpscaleRealESRGANStage(), esrgan_params, upscale_opts) elif "stable-diffusion" in upscale.upscale_model: mini_tile = min(SizeChart.mini, stage.tile_size) sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) - upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts) + upscale_stage = (UpscaleStableDiffusionStage(), sd_params, upscale_opts) elif "swinir" in upscale.upscale_model: swinir_params = StageParams( tile_size=stage.tile_size, outscale=upscale.outscale, ) - upscale_stage = (upscale_swinir, swinir_params, upscale_opts) + upscale_stage = (UpscaleSwinIRStage(), swinir_params, upscale_opts) else: logger.warn("unknown upscaling model: %s", upscale.upscale_model) @@ -98,9 +98,9 @@ def stage_upscale_correction( tile_size=stage.tile_size, outscale=upscale.face_outscale ) if "codeformer" in upscale.correction_model: - correct_stage = (correct_codeformer, face_params, upscale_opts) + correct_stage = (CorrectCodeformerStage(), face_params, upscale_opts) elif "gfpgan" in upscale.correction_model: - correct_stage = (correct_gfpgan, face_params, upscale_opts) + correct_stage = (CorrectGFPGANStage(), face_params, upscale_opts) else: logger.warn("unknown correction model: %s", 
upscale.correction_model) diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index 37350d09..78faa2c5 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -6,7 +6,7 @@ import numpy as np from PIL import Image from ..models.onnx import OnnxModel -from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams +from ..params import DeviceParams, ImageParams, Size, StageParams, UpscaleParams from ..server import ServerContext from ..utils import run_gc from ..worker import WorkerContext @@ -14,105 +14,121 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def load_bsrgan( - server: ServerContext, - _stage: StageParams, - upscale: UpscaleParams, - device: DeviceParams, -): - # must be within the load function for patch to take effect - model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model)) - cache_key = (model_path,) - cache_pipe = server.cache.get("bsrgan", cache_key) +class UpscaleBSRGANStage: + max_tile = 64 - if cache_pipe is not None: - logger.debug("reusing existing BSRGAN pipeline") - return cache_pipe + def load( + self, + server: ServerContext, + _stage: StageParams, + upscale: UpscaleParams, + device: DeviceParams, + ): + # must be within the load function for patch to take effect + model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model)) + cache_key = (model_path,) + cache_pipe = server.cache.get("bsrgan", cache_key) - logger.info("loading BSRGAN model from %s", model_path) + if cache_pipe is not None: + logger.debug("reusing existing BSRGAN pipeline") + return cache_pipe - pipe = OnnxModel( - server, - model_path, - provider=device.ort_provider(), - sess_options=device.sess_options(), - ) + logger.info("loading BSRGAN model from %s", model_path) - server.cache.set("bsrgan", cache_key, pipe) - run_gc([device]) + pipe = OnnxModel( + server, + model_path, + provider=device.ort_provider(), + sess_options=device.sess_options(), + ) - return pipe + server.cache.set("bsrgan", cache_key, pipe) + run_gc([device]) + return pipe -def upscale_bsrgan( - job: WorkerContext, - server: ServerContext, - stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - upscale: UpscaleParams, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - upscale = upscale.with_args(**kwargs) - source = stage_source or source + def run( + self, + job: WorkerContext, + server: ServerContext, + stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + upscale: UpscaleParams, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + upscale = upscale.with_args(**kwargs) + source = stage_source or source - if upscale.upscale_model is None: - logger.warn("no upscaling model given, skipping") - return source + if upscale.upscale_model is None: + logger.warn("no upscaling model given, skipping") + return source - logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model) - device = job.get_device() - bsrgan = load_bsrgan(server, stage, upscale, device) + logger.info("upscaling with BSRGAN model: %s", upscale.upscale_model) + device = job.get_device() + bsrgan = self.load(server, stage, upscale, device) - tile_size = (64, 64) - tile_x = source.width // tile_size[0] - tile_y = source.height // tile_size[1] + tile_size = (64, 64) + tile_x = source.width // tile_size[0] + tile_y = source.height // tile_size[1] - image = np.array(source) / 255.0 - image = image[:, :, [2, 1, 
0]].astype(np.float32).transpose((2, 0, 1)) - image = np.expand_dims(image, axis=0) - logger.trace("BSRGAN input shape: %s", image.shape) + image = np.array(source) / 255.0 + image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) + image = np.expand_dims(image, axis=0) + logger.trace("BSRGAN input shape: %s", image.shape) - scale = upscale.outscale - dest = np.zeros( - (image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale) - ) - logger.trace("BSRGAN output shape: %s", dest.shape) - - for x in range(tile_x): - for y in range(tile_y): - xt = x * tile_size[0] - yt = y * tile_size[1] - - ix1 = xt - ix2 = xt + tile_size[0] - iy1 = yt - iy2 = yt + tile_size[1] - logger.debug( - "running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)", - ix1, - ix2, - iy1, - iy2, - ix1 * scale, - ix2 * scale, - iy1 * scale, - iy2 * scale, + scale = upscale.outscale + dest = np.zeros( + ( + image.shape[0], + image.shape[1], + image.shape[2] * scale, + image.shape[3] * scale, ) + ) + logger.trace("BSRGAN output shape: %s", dest.shape) - dest[ - :, - :, - ix1 * scale : ix2 * scale, - iy1 * scale : iy2 * scale, - ] = bsrgan(image[:, :, ix1:ix2, iy1:iy2]) + for x in range(tile_x): + for y in range(tile_y): + xt = x * tile_size[0] + yt = y * tile_size[1] - dest = np.clip(np.squeeze(dest, axis=0), 0, 1) - dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) - dest = (dest * 255.0).round().astype(np.uint8) + ix1 = xt + ix2 = xt + tile_size[0] + iy1 = yt + iy2 = yt + tile_size[1] + logger.debug( + "running BSRGAN on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)", + ix1, + ix2, + iy1, + iy2, + ix1 * scale, + ix2 * scale, + iy1 * scale, + iy2 * scale, + ) - output = Image.fromarray(dest, "RGB") - logger.debug("output image size: %s x %s", output.width, output.height) - return output + dest[ + :, + :, + ix1 * scale : ix2 * scale, + iy1 * scale : iy2 * scale, + ] = bsrgan(image[:, :, ix1:ix2, iy1:iy2]) + + dest = np.clip(np.squeeze(dest, axis=0), 0, 1) + dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) + dest = (dest * 255.0).round().astype(np.uint8) + + output = Image.fromarray(dest, "RGB") + logger.debug("output image size: %s x %s", output.width, output.height) + return output + + def steps( + self, + _params: ImageParams, + size: Size, + ) -> int: + return size.width // self.max_tile * size.height // self.max_tile diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index acfe9452..0336feb8 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -3,70 +3,71 @@ from typing import Any, Optional from PIL import Image -from ..chain.base import ChainPipeline -from ..chain.blend_img2img import blend_img2img -from ..diffusers.upscale import stage_upscale_correction +from ..chain import BlendImg2ImgStage, ChainPipeline from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext from ..worker.context import ProgressCallback +from .upscale import stage_upscale_correction logger = getLogger(__name__) -def upscale_highres( - job: WorkerContext, - server: ServerContext, - _stage: StageParams, - params: ImageParams, - source: Image.Image, - *, - highres: HighresParams, - upscale: UpscaleParams, - stage_source: Optional[Image.Image] = None, - pipeline: Optional[Any] = None, - callback: Optional[ProgressCallback] = None, - **kwargs, -) -> Image.Image: - source = stage_source or source +class UpscaleHighresStage: + def run( + self, + 
job: WorkerContext, + server: ServerContext, + _stage: StageParams, + params: ImageParams, + source: Image.Image, + *, + highres: HighresParams, + upscale: UpscaleParams, + stage_source: Optional[Image.Image] = None, + pipeline: Optional[Any] = None, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> Image.Image: + source = stage_source or source - if highres.scale <= 1: - return source + if highres.scale <= 1: + return source - chain = ChainPipeline() - scaled_size = (source.width * highres.scale, source.height * highres.scale) + chain = ChainPipeline() + scaled_size = (source.width * highres.scale, source.height * highres.scale) - # TODO: upscaling within the same stage prevents tiling from happening and causes OOM - if highres.method == "bilinear": - logger.debug("using bilinear interpolation for highres") - source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR) - elif highres.method == "lanczos": - logger.debug("using Lanczos interpolation for highres") - source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) - else: - logger.debug("using upscaling pipeline for highres") - stage_upscale_correction( + # TODO: upscaling within the same stage prevents tiling from happening and causes OOM + if highres.method == "bilinear": + logger.debug("using bilinear interpolation for highres") + source = source.resize(scaled_size, resample=Image.Resampling.BILINEAR) + elif highres.method == "lanczos": + logger.debug("using Lanczos interpolation for highres") + source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) + else: + logger.debug("using upscaling pipeline for highres") + stage_upscale_correction( + StageParams(), + params, + upscale=upscale.with_args( + faces=False, + scale=highres.scale, + outscale=highres.scale, + ), + chain=chain, + ) + + chain.stage( + BlendImg2ImgStage(), StageParams(), - params, - upscale=upscale.with_args( - faces=False, - scale=highres.scale, - outscale=highres.scale, - ), - chain=chain, + overlap=params.overlap, + strength=highres.strength, ) - chain.stage( - blend_img2img, - StageParams(), - overlap=params.overlap, - strength=highres.strength, - ) - - return chain( - job, - server, - params, - source, - callback=callback, - ) + return chain( + job, + server, + params, + source, + callback=callback, + ) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 4db1a3ef..9b94e582 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -18,138 +18,140 @@ from .utils import complete_tile, process_tile_grid, process_tile_order logger = getLogger(__name__) -def upscale_outpaint( - job: WorkerContext, - server: ServerContext, - stage: StageParams, - params: ImageParams, - source: Image.Image, - *, - border: Border, - stage_source: Optional[Image.Image] = None, - stage_mask: Optional[Image.Image] = None, - fill_color: str = "white", - mask_filter: Callable = mask_filter_none, - noise_source: Callable = noise_source_histogram, - callback: Optional[ProgressCallback] = None, - **kwargs, -) -> Image.Image: - source = stage_source or source - logger.info( - "upscaling %s x %s image by expanding borders: %s", - source.width, - source.height, - border, - ) +class UpscaleOutpaintStage: + def run( + self, + job: WorkerContext, + server: ServerContext, + stage: StageParams, + params: ImageParams, + source: Image.Image, + *, + border: Border, + stage_source: Optional[Image.Image] = None, + stage_mask: Optional[Image.Image] = None, + 
fill_color: str = "white", + mask_filter: Callable = mask_filter_none, + noise_source: Callable = noise_source_histogram, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> Image.Image: + source = stage_source or source + logger.info( + "upscaling %s x %s image by expanding borders: %s", + source.width, + source.height, + border, + ) - margin_x = float(max(border.left, border.right)) - margin_y = float(max(border.top, border.bottom)) - overlap = min(margin_x / source.width, margin_y / source.height) + margin_x = float(max(border.left, border.right)) + margin_y = float(max(border.top, border.bottom)) + overlap = min(margin_x / source.width, margin_y / source.height) - if stage_mask is None: - # if no mask was provided, keep the full source image - stage_mask = Image.new("RGB", source.size, "black") + if stage_mask is None: + # if no mask was provided, keep the full source image + stage_mask = Image.new("RGB", source.size, "black") - # masks start as 512x512, resize to cover the source, then trim the extra - mask_max = max(source.width, source.height) - stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max)) - stage_mask = stage_mask.crop((0, 0, source.width, source.height)) + # masks start as 512x512, resize to cover the source, then trim the extra + mask_max = max(source.width, source.height) + stage_mask = ImageOps.contain(stage_mask, (mask_max, mask_max)) + stage_mask = stage_mask.crop((0, 0, source.width, source.height)) - source, stage_mask, noise, full_size = expand_image( - source, - stage_mask, - border, - fill=fill_color, - noise_source=noise_source, - mask_filter=mask_filter, - ) + source, stage_mask, noise, full_size = expand_image( + source, + stage_mask, + border, + fill=fill_color, + noise_source=noise_source, + mask_filter=mask_filter, + ) - full_latents = get_latents_from_seed(params.seed, Size(*full_size)) + full_latents = get_latents_from_seed(params.seed, Size(*full_size)) - draw_mask = ImageDraw.Draw(stage_mask) - - if is_debug(): - save_image(server, "last-source.png", source) - save_image(server, "last-mask.png", stage_mask) - save_image(server, "last-noise.png", noise) - - def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): - left, top, tile = dims - size = Size(*tile_source.size) - tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) - tile_mask = complete_tile(tile_mask, tile) + draw_mask = ImageDraw.Draw(stage_mask) if is_debug(): - save_image(server, "tile-source.png", tile_source) - save_image(server, "tile-mask.png", tile_mask) + save_image(server, "last-source.png", source) + save_image(server, "last-mask.png", stage_mask) + save_image(server, "last-noise.png", noise) - latents = get_tile_latents(full_latents, dims, size) - pipe_type = params.get_valid_pipeline("inpaint", params.pipeline) - pipe = load_pipeline( - server, - params, - pipe_type, - job.get_device(), - # TODO: load LoRAs and TIs - ) - if params.lpw(): - logger.debug("using LPW pipeline for inpaint") - rng = torch.manual_seed(params.seed) - result = pipe.inpaint( - tile_source, - tile_mask, - params.prompt, - generator=rng, - guidance_scale=params.cfg, - height=size.height, - latents=latents, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - width=size.width, - callback=callback, + def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): + left, top, tile = dims + size = Size(*tile_source.size) + tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) + tile_mask = complete_tile(tile_mask, tile) + 
+ if is_debug(): + save_image(server, "tile-source.png", tile_source) + save_image(server, "tile-mask.png", tile_mask) + + latents = get_tile_latents(full_latents, dims, size) + pipe_type = params.get_valid_pipeline("inpaint", params.pipeline) + pipe = load_pipeline( + server, + params, + pipe_type, + job.get_device(), + # TODO: load LoRAs and TIs + ) + if params.lpw(): + logger.debug("using LPW pipeline for inpaint") + rng = torch.manual_seed(params.seed) + result = pipe.inpaint( + tile_source, + tile_mask, + params.prompt, + generator=rng, + guidance_scale=params.cfg, + height=size.height, + latents=latents, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + width=size.width, + callback=callback, + ) + else: + rng = np.random.RandomState(params.seed) + result = pipe( + params.prompt, + tile_source, + tile_mask, + height=size.height, + width=size.width, + num_inference_steps=params.steps, + guidance_scale=params.cfg, + negative_prompt=params.negative_prompt, + generator=rng, + latents=latents, + callback=callback, + ) + + # once part of the image has been drawn, keep it + draw_mask.rectangle((left, top, left + tile, top + tile), fill="black") + return result.images[0] + + if params.pipeline == "panorama": + logger.debug("outpainting with one shot panorama, no tiling") + return outpaint(source, (0, 0, max(source.width, source.height))) + if overlap == 0: + logger.debug("outpainting with 0 margin, using grid tiling") + output = process_tile_grid(source, SizeChart.auto, 1, [outpaint]) + elif border.left == border.right and border.top == border.bottom: + logger.debug( + "outpainting with an even border, using spiral tiling with %s overlap", + overlap, + ) + output = process_tile_order( + stage.tile_order, + source, + SizeChart.auto, + 1, + [outpaint], + overlap=overlap, ) else: - rng = np.random.RandomState(params.seed) - result = pipe( - params.prompt, - tile_source, - tile_mask, - height=size.height, - width=size.width, - num_inference_steps=params.steps, - guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, - generator=rng, - latents=latents, - callback=callback, - ) + logger.debug("outpainting with an uneven border, using grid tiling") + output = process_tile_grid(source, SizeChart.auto, 1, [outpaint]) - # once part of the image has been drawn, keep it - draw_mask.rectangle((left, top, left + tile, top + tile), fill="black") - return result.images[0] - - if params.pipeline == "panorama": - logger.debug("outpainting with one shot panorama, no tiling") - return outpaint(source, (0, 0, max(source.width, source.height))) - if overlap == 0: - logger.debug("outpainting with 0 margin, using grid tiling") - output = process_tile_grid(source, SizeChart.auto, 1, [outpaint]) - elif border.left == border.right and border.top == border.bottom: - logger.debug( - "outpainting with an even border, using spiral tiling with %s overlap", - overlap, - ) - output = process_tile_order( - stage.tile_order, - source, - SizeChart.auto, - 1, - [outpaint], - overlap=overlap, - ) - else: - logger.debug("outpainting with an uneven border, using grid tiling") - output = process_tile_grid(source, SizeChart.auto, 1, [outpaint]) - - logger.info("final output image size: %sx%s", output.width, output.height) - return output + logger.info("final output image size: %sx%s", output.width, output.height) + return output diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index f3e68ced..b4b99bde 100644 --- 
a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -16,79 +16,80 @@ logger = getLogger(__name__) TAG_X4_V3 = "real-esrgan-x4-v3" -def load_resrgan( - server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0 -): - # must be within load function for patches to take effect - # TODO: rewrite and remove - from realesrgan import RealESRGANer +class UpscaleRealESRGANStage: + def load( + self, server: ServerContext, params: UpscaleParams, device: DeviceParams, tile=0 + ): + # must be within load function for patches to take effect + # TODO: rewrite and remove + from realesrgan import RealESRGANer - model_file = "%s.%s" % (params.upscale_model, params.format) - model_path = path.join(server.model_path, model_file) + model_file = "%s.%s" % (params.upscale_model, params.format) + model_path = path.join(server.model_path, model_file) - cache_key = (model_path, params.format) - cache_pipe = server.cache.get("resrgan", cache_key) - if cache_pipe is not None: - logger.info("reusing existing Real ESRGAN pipeline") - return cache_pipe + cache_key = (model_path, params.format) + cache_pipe = server.cache.get("resrgan", cache_key) + if cache_pipe is not None: + logger.info("reusing existing Real ESRGAN pipeline") + return cache_pipe - if not path.isfile(model_path): - raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path) + if not path.isfile(model_path): + raise FileNotFoundError("Real ESRGAN model not found at %s" % model_path) - # TODO: swap for regular RRDBNet after rewriting wrapper - model = OnnxRRDBNet( - server, - model_file, - provider=device.ort_provider(), - sess_options=device.sess_options(), - ) + # TODO: swap for regular RRDBNet after rewriting wrapper + model = OnnxRRDBNet( + server, + model_file, + provider=device.ort_provider(), + sess_options=device.sess_options(), + ) - dni_weight = None - if params.upscale_model == TAG_X4_V3 and params.denoise != 1: - wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3) - model_path = [model_path, wdn_model_path] - dni_weight = [params.denoise, 1 - params.denoise] + dni_weight = None + if params.upscale_model == TAG_X4_V3 and params.denoise != 1: + wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3) + model_path = [model_path, wdn_model_path] + dni_weight = [params.denoise, 1 - params.denoise] - logger.debug("loading Real ESRGAN upscale model from %s", model_path) + logger.debug("loading Real ESRGAN upscale model from %s", model_path) - # TODO: shouldn't need the PTH file - model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model)) - upsampler = RealESRGANer( - scale=params.scale, - model_path=model_path_pth, - dni_weight=dni_weight, - model=model, - tile=tile, - tile_pad=params.tile_pad, - pre_pad=params.pre_pad, - half=False, # TODO: use server optimizations - ) + # TODO: shouldn't need the PTH file + model_path_pth = path.join(server.cache_path, ("%s.pth" % params.upscale_model)) + upsampler = RealESRGANer( + scale=params.scale, + model_path=model_path_pth, + dni_weight=dni_weight, + model=model, + tile=tile, + tile_pad=params.tile_pad, + pre_pad=params.pre_pad, + half=False, # TODO: use server optimizations + ) - server.cache.set("resrgan", cache_key, upsampler) - run_gc([device]) + server.cache.set("resrgan", cache_key, upsampler) + run_gc([device]) - return upsampler + return upsampler + def run( + self, + job: WorkerContext, + server: ServerContext, + stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, 
+ upscale: UpscaleParams, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + source = stage_source or source + logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) -def upscale_resrgan( - job: WorkerContext, - server: ServerContext, - stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - upscale: UpscaleParams, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - source = stage_source or source - logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) + output = np.array(source) + upsampler = self.load(server, upscale, job.get_device(), tile=stage.tile_size) - output = np.array(source) - upsampler = load_resrgan(server, upscale, job.get_device(), tile=stage.tile_size) + output, _ = upsampler.enhance(output, outscale=upscale.outscale) - output, _ = upsampler.enhance(output, outscale=upscale.outscale) - - output = Image.fromarray(output, "RGB") - logger.info("final output image size: %sx%s", output.width, output.height) - return output + output = Image.fromarray(output, "RGB") + logger.info("final output image size: %sx%s", output.width, output.height) + return output diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index f5cd42ee..013261e2 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -14,52 +14,54 @@ from ..worker import ProgressCallback, WorkerContext logger = getLogger(__name__) -def upscale_stable_diffusion( - job: WorkerContext, - server: ServerContext, - _stage: StageParams, - params: ImageParams, - source: Image.Image, - *, - upscale: UpscaleParams, - stage_source: Optional[Image.Image] = None, - callback: Optional[ProgressCallback] = None, - **kwargs, -) -> Image.Image: - params = params.with_args(**kwargs) - upscale = upscale.with_args(**kwargs) - source = stage_source or source - logger.info( - "upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt - ) +class UpscaleStableDiffusionStage: + def run( + self, + job: WorkerContext, + server: ServerContext, + _stage: StageParams, + params: ImageParams, + source: Image.Image, + *, + upscale: UpscaleParams, + stage_source: Optional[Image.Image] = None, + callback: Optional[ProgressCallback] = None, + **kwargs, + ) -> Image.Image: + params = params.with_args(**kwargs) + upscale = upscale.with_args(**kwargs) + source = stage_source or source + logger.info( + "upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt + ) - prompt_pairs, _loras, _inversions = parse_prompt(params) + prompt_pairs, _loras, _inversions = parse_prompt(params) - pipeline = load_pipeline( - server, - params, - "upscale", - job.get_device(), - model=path.join(server.model_path, upscale.upscale_model), - ) - generator = torch.manual_seed(params.seed) + pipeline = load_pipeline( + server, + params, + "upscale", + job.get_device(), + model=path.join(server.model_path, upscale.upscale_model), + ) + generator = torch.manual_seed(params.seed) - prompt_embeds = encode_prompt( - pipeline, - prompt_pairs, - num_images_per_prompt=params.batch, - do_classifier_free_guidance=params.do_cfg(), - ) - pipeline.unet.set_prompts(prompt_embeds) + prompt_embeds = encode_prompt( + pipeline, + prompt_pairs, + num_images_per_prompt=params.batch, + do_classifier_free_guidance=params.do_cfg(), + ) + pipeline.unet.set_prompts(prompt_embeds) - return pipeline( - params.prompt, - source, - generator=generator, - 
guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, - num_inference_steps=params.steps, - eta=params.eta, - noise_level=upscale.denoise, - callback=callback, - ).images[0] + return pipeline( + params.prompt, + source, + generator=generator, + guidance_scale=params.cfg, + negative_prompt=params.negative_prompt, + num_inference_steps=params.steps, + eta=params.eta, + noise_level=upscale.denoise, + callback=callback, + ).images[0] diff --git a/api/onnx_web/chain/upscale_swinir.py b/api/onnx_web/chain/upscale_swinir.py index aaf466cc..47d0dbaa 100644 --- a/api/onnx_web/chain/upscale_swinir.py +++ b/api/onnx_web/chain/upscale_swinir.py @@ -14,107 +14,116 @@ from ..worker import WorkerContext logger = getLogger(__name__) -def load_swinir( - server: ServerContext, - _stage: StageParams, - upscale: UpscaleParams, - device: DeviceParams, -): - # must be within the load function for patch to take effect - model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model)) - cache_key = (model_path,) - cache_pipe = server.cache.get("swinir", cache_key) +class UpscaleSwinIRStage: + max_tile = 64 - if cache_pipe is not None: - logger.info("reusing existing SwinIR pipeline") - return cache_pipe + def load( + self, + server: ServerContext, + _stage: StageParams, + upscale: UpscaleParams, + device: DeviceParams, + ): + # must be within the load function for patch to take effect + model_path = path.join(server.model_path, "%s.onnx" % (upscale.upscale_model)) + cache_key = (model_path,) + cache_pipe = server.cache.get("swinir", cache_key) - logger.debug("loading SwinIR model from %s", model_path) + if cache_pipe is not None: + logger.info("reusing existing SwinIR pipeline") + return cache_pipe - pipe = OnnxModel( - server, - model_path, - provider=device.ort_provider(), - sess_options=device.sess_options(), - ) + logger.debug("loading SwinIR model from %s", model_path) - server.cache.set("swinir", cache_key, pipe) - run_gc([device]) + pipe = OnnxModel( + server, + model_path, + provider=device.ort_provider(), + sess_options=device.sess_options(), + ) - return pipe + server.cache.set("swinir", cache_key, pipe) + run_gc([device]) + return pipe -def upscale_swinir( - job: WorkerContext, - server: ServerContext, - stage: StageParams, - _params: ImageParams, - source: Image.Image, - *, - upscale: UpscaleParams, - stage_source: Optional[Image.Image] = None, - **kwargs, -) -> Image.Image: - upscale = upscale.with_args(**kwargs) - source = stage_source or source + def run( + self, + job: WorkerContext, + server: ServerContext, + stage: StageParams, + _params: ImageParams, + source: Image.Image, + *, + upscale: UpscaleParams, + stage_source: Optional[Image.Image] = None, + **kwargs, + ) -> Image.Image: + upscale = upscale.with_args(**kwargs) + source = stage_source or source - if upscale.upscale_model is None: - logger.warn("no correction model given, skipping") - return source + if upscale.upscale_model is None: + logger.warn("no correction model given, skipping") + return source - logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model) - device = job.get_device() - swinir = load_swinir(server, stage, upscale, device) + logger.info("correcting faces with SwinIR model: %s", upscale.upscale_model) + device = job.get_device() + swinir = self.load(server, stage, upscale, device) - # TODO: add support for other sizes - tile_size = (64, 64) - tile_x = source.width // tile_size[0] - tile_y = source.height // tile_size[1] + # TODO: add support for other sizes + tile_size = 
(64, 64) + tile_x = source.width // tile_size[0] + tile_y = source.height // tile_size[1] - # TODO: add support for grayscale (1-channel) images - image = np.array(source) / 255.0 - image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) - image = np.expand_dims(image, axis=0) - logger.info("SwinIR input shape: %s", image.shape) + # TODO: add support for grayscale (1-channel) images + image = np.array(source) / 255.0 + image = image[:, :, [2, 1, 0]].astype(np.float32).transpose((2, 0, 1)) + image = np.expand_dims(image, axis=0) + logger.info("SwinIR input shape: %s", image.shape) - scale = upscale.outscale - dest = np.zeros( - (image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale) - ) - logger.info("SwinIR output shape: %s", dest.shape) - - for x in range(tile_x): - for y in range(tile_y): - xt = x * tile_size[0] - yt = y * tile_size[1] - - ix1 = xt - ix2 = xt + tile_size[0] - iy1 = yt - iy2 = yt + tile_size[1] - logger.info( - "running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)", - ix1, - ix2, - iy1, - iy2, - ix1 * scale, - ix2 * scale, - iy1 * scale, - iy2 * scale, + scale = upscale.outscale + dest = np.zeros( + ( + image.shape[0], + image.shape[1], + image.shape[2] * scale, + image.shape[3] * scale, ) + ) + logger.info("SwinIR output shape: %s", dest.shape) - dest[ - :, - :, - ix1 * scale : ix2 * scale, - iy1 * scale : iy2 * scale, - ] = swinir(image[:, :, ix1:ix2, iy1:iy2]) + for x in range(tile_x): + for y in range(tile_y): + xt = x * tile_size[0] + yt = y * tile_size[1] - dest = np.clip(np.squeeze(dest, axis=0), 0, 1) - dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) - dest = (dest * 255.0).round().astype(np.uint8) + ix1 = xt + ix2 = xt + tile_size[0] + iy1 = yt + iy2 = yt + tile_size[1] + logger.info( + "running SwinIR on tile: (%s, %s, %s, %s) -> (%s, %s, %s, %s)", + ix1, + ix2, + iy1, + iy2, + ix1 * scale, + ix2 * scale, + iy1 * scale, + iy2 * scale, + ) - output = Image.fromarray(dest, "RGB") - logger.info("output image size: %s x %s", output.width, output.height) - return output + dest[ + :, + :, + ix1 * scale : ix2 * scale, + iy1 * scale : iy2 * scale, + ] = swinir(image[:, :, ix1:ix2, iy1:iy2]) + + dest = np.clip(np.squeeze(dest, axis=0), 0, 1) + dest = dest[[2, 1, 0], :, :].transpose((1, 2, 0)) + dest = (dest * 255.0).round().astype(np.uint8) + + output = Image.fromarray(dest, "RGB") + logger.info("output image size: %s x %s", output.width, output.height) + return output diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index c66b9c84..d0fd239e 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -4,13 +4,14 @@ from typing import Any, List, Optional from PIL import Image from ..chain import ( - blend_img2img, - blend_mask, - source_txt2img, - upscale_highres, - upscale_outpaint, + BlendImg2ImgStage, + BlendMaskStage, + ChainPipeline, + SourceTxt2ImgStage, + UpscaleHighresStage, + UpscaleOutpaintStage, ) -from ..chain.base import ChainPipeline +from ..chain.upscale import split_upscale, stage_upscale_correction from ..output import save_image from ..params import ( Border, @@ -24,7 +25,6 @@ from ..server import ServerContext from ..server.load import get_source_filters from ..utils import run_gc, show_system_toast from ..worker import WorkerContext -from .upscale import split_upscale, stage_upscale_correction from .utils import parse_prompt logger = getLogger(__name__) @@ -43,7 +43,7 @@ def run_txt2img_pipeline( chain = ChainPipeline() stage = StageParams() chain.stage( - 
source_txt2img, + SourceTxt2ImgStage(), stage, size=size, ) @@ -61,7 +61,7 @@ def run_txt2img_pipeline( # apply highres for _i in range(highres.iterations): chain.stage( - upscale_highres, + UpscaleHighresStage(), StageParams( outscale=highres.scale, ), @@ -125,7 +125,7 @@ def run_img2img_pipeline( chain = ChainPipeline() stage = StageParams() chain.stage( - blend_img2img, + BlendImg2ImgStage(), stage, strength=strength, ) @@ -144,7 +144,7 @@ def run_img2img_pipeline( if params.loopback > 0: for _i in range(params.loopback): chain.stage( - blend_img2img, + BlendImg2ImgStage(), stage, strength=strength, ) @@ -153,7 +153,7 @@ def run_img2img_pipeline( if highres.iterations > 0: for _i in range(highres.iterations): chain.stage( - upscale_highres, + UpscaleHighresStage(), stage, highres=highres, upscale=upscale, @@ -223,7 +223,7 @@ def run_inpaint_pipeline( chain = ChainPipeline() stage = StageParams(tile_order=tile_order) chain.stage( - upscale_outpaint, + UpscaleOutpaintStage(), stage, border=border, stage_mask=mask, @@ -234,7 +234,7 @@ def run_inpaint_pipeline( # apply highres chain.stage( - upscale_highres, + UpscaleHighresStage(), stage, highres=highres, upscale=upscale, @@ -300,7 +300,7 @@ def run_upscale_pipeline( # apply highres chain.stage( - upscale_highres, + UpscaleHighresStage(), stage, highres=highres, upscale=upscale, @@ -353,7 +353,7 @@ def run_blend_pipeline( # set up the chain pipeline and base stage chain = ChainPipeline() stage = StageParams() - chain.stage(blend_mask, stage, stage_source=sources[1], stage_mask=mask) + chain.stage(BlendMaskStage(), stage, stage_source=sources[1], stage_mask=mask) # apply upscaling and correction stage_upscale_correction( diff --git a/api/onnx_web/image/utils.py b/api/onnx_web/image/utils.py index ee8dc59e..e3fe7c43 100644 --- a/api/onnx_web/image/utils.py +++ b/api/onnx_web/image/utils.py @@ -1,6 +1,4 @@ -from typing import Tuple, Union - -from PIL import Image, ImageChops, ImageOps +from PIL import Image, ImageChops from ..params import Border, Size from .mask_filter import mask_filter_none diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 49d2be93..fdcae5b3 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -356,12 +356,12 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): pipeline = ChainPipeline() for stage_data in data.get("stages", []): - callback = CHAIN_STAGES[stage_data.get("type")] + stage_class = CHAIN_STAGES[stage_data.get("type")] kwargs = stage_data.get("params", {}) - logger.info("request stage: %s, %s", callback.__name__, kwargs) + logger.info("request stage: %s, %s", stage_class.__name__, kwargs) stage = StageParams( - stage_data.get("name", callback.__name__), + stage_data.get("name", stage_class.__name__), tile_size=get_size(kwargs.get("tile_size")), outscale=get_and_clamp_int(kwargs, "outscale", 1, 4), ) @@ -399,7 +399,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): mask = Image.open(BytesIO(mask_file.read())).convert("RGB") kwargs["stage_mask"] = mask - pipeline.append((callback, stage, kwargs)) + pipeline.append((stage_class(), stage, kwargs)) logger.info("running chain pipeline with %s stages", len(pipeline.stages))
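The converted stages above all share one interface: an optional load() that builds and caches the underlying model, and a run() that receives the worker and server contexts, the stage and image params, and the source image, with stage-specific options passed as keyword arguments. Stages with a hard tile limit also expose a max_tile class attribute, as UpscaleSwinIRStage does. A minimal sketch of a custom stage against that interface, assuming the package is importable as onnx_web; the InvertStage class, its behavior, and its max_tile value are illustrative and not part of this patch:

from typing import Optional

from PIL import Image, ImageOps

from onnx_web.params import ImageParams, StageParams
from onnx_web.server import ServerContext
from onnx_web.worker import WorkerContext


class InvertStage:
    # hypothetical stage: no model to load, so only run() is defined
    max_tile = 512  # assumed per-tile limit; model-backed stages use smaller values

    def run(
        self,
        _job: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        source: Image.Image,
        *,
        stage_source: Optional[Image.Image] = None,
        **kwargs,
    ) -> Image.Image:
        # prefer an explicit stage_source, falling back to the previous stage's output
        source = stage_source or source
        # invert the RGB channels as a stand-in for real image processing
        return ImageOps.invert(source.convert("RGB"))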
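Because CHAIN_STAGES now maps stage names to classes rather than callbacks, the /chain endpoint instantiates the class per request, and run.py composes pipelines from instances in the same way. A rough sketch of both paths, reusing the hypothetical InvertStage from the previous example; the "invert" name is an assumption and is not registered by this patch:

from onnx_web.chain import CHAIN_STAGES, ChainPipeline
from onnx_web.params import StageParams

# InvertStage as defined in the sketch above (hypothetical)
# expose it to JSON chain requests by name
CHAIN_STAGES["invert"] = InvertStage

# or build a pipeline directly, as run.py does with the new stage classes
chain = ChainPipeline()
chain.stage(InvertStage(), StageParams("invert"))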