From ac1f7449bb50463b10635a388431d0d1a324a954 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Feb 2023 22:11:44 -0600 Subject: [PATCH] fix(api): use stage source when available --- api/onnx_web/chain/blend_img2img.py | 2 ++ api/onnx_web/chain/blend_inpaint.py | 18 +++++++------ api/onnx_web/chain/blend_mask.py | 8 +++--- api/onnx_web/chain/correct_codeformer.py | 4 ++- api/onnx_web/chain/correct_gfpgan.py | 2 ++ api/onnx_web/chain/persist_disk.py | 3 +++ api/onnx_web/chain/persist_s3.py | 3 +++ api/onnx_web/chain/reduce_crop.py | 3 +++ api/onnx_web/chain/reduce_thumbnail.py | 2 ++ api/onnx_web/chain/source_noise.py | 2 ++ api/onnx_web/chain/source_txt2img.py | 4 +-- api/onnx_web/chain/upscale_outpaint.py | 26 +++++++++---------- api/onnx_web/chain/upscale_resrgan.py | 2 ++ .../chain/upscale_stable_diffusion.py | 2 ++ api/onnx_web/diffusion/run.py | 4 +-- 15 files changed, 55 insertions(+), 30 deletions(-) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index 525f2507..0baa3153 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -20,9 +20,11 @@ def blend_img2img( source: Image.Image, *, callback: ProgressCallback = None, + stage_source: Image.Image, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) + source = stage_source or source logger.info( "blending image using img2img, %s steps: %s", params.steps, params.prompt ) diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 1234aef7..bd51eb3e 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -25,7 +25,8 @@ def blend_inpaint( source: Image.Image, *, expand: Border, - mask: Optional[Image.Image] = None, + stage_source: Optional[Image.Image] = None, + stage_mask: Optional[Image.Image] = None, fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, @@ -34,17 +35,18 @@ def blend_inpaint( ) -> Image.Image: params = params.with_args(**kwargs) expand = expand.with_args(**kwargs) + source = source or stage_source logger.info( "blending image using inpaint, %s steps: %s", params.steps, params.prompt ) - if mask is None: + if stage_mask is None: # if no mask was provided, keep the full source image - mask = Image.new("RGB", source.size, "black") + stage_mask = Image.new("RGB", source.size, "black") - source, mask, noise, _full_dims = expand_image( + source, stage_mask, noise, _full_dims = expand_image( source, - mask, + stage_mask, expand, fill=fill_color, noise_source=noise_source, @@ -53,13 +55,13 @@ def blend_inpaint( if is_debug(): save_image(server, "last-source.png", source) - save_image(server, "last-mask.png", mask) + save_image(server, "last-mask.png", stage_mask) save_image(server, "last-noise.png", noise) def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): left, top, tile = dims size = Size(*tile_source.size) - tile_mask = mask.crop((left, top, left + tile, top + tile)) + tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) if is_debug(): save_image(server, "tile-source.png", tile_source) @@ -100,7 +102,7 @@ def blend_inpaint( height=size.height, image=tile_source, latents=latents, - mask_image=mask, + mask_image=stage_mask, negative_prompt=params.negative_prompt, num_inference_steps=params.steps, width=size.width, diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 1fbe80ef..fc24d23c 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -20,18 +20,18 @@ def blend_mask( _params: ImageParams, *, sources: Optional[List[Image.Image]] = None, - mask: Optional[Image.Image] = None, + stage_mask: Optional[Image.Image] = None, _callback: ProgressCallback = None, **kwargs, ) -> Image.Image: logger.info("blending image using mask") - mult_mask = Image.new("RGBA", mask.size, color="black") - mult_mask.alpha_composite(mask) + mult_mask = Image.new("RGBA", stage_mask.size, color="black") + mult_mask.alpha_composite(stage_mask) mult_mask = mult_mask.convert("L") if is_debug(): - save_image(server, "last-mask.png", mask) + save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) resized = [ diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index f6b26203..6b4e235d 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -24,8 +24,10 @@ def correct_codeformer( # must be within the load function for patch to take effect from codeformer import CodeFormer + source = stage_source or source + upscale = upscale.with_args(**kwargs) device = job.get_device() pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str()) - return pipe(stage_source or source) + return pipe(source) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index c9efc96a..6796c1ba 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -53,9 +53,11 @@ def correct_gfpgan( source: Image.Image, *, upscale: UpscaleParams, + stage_source: Image.Image, **kwargs, ) -> Image.Image: upscale = upscale.with_args(**kwargs) + source = stage_source or source if upscale.correction_model is None: logger.warn("no face model given, skipping") diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index a8162607..9a5f0cd0 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -17,8 +17,11 @@ def persist_disk( source: Image.Image, *, output: str, + stage_source: Image.Image, **kwargs, ) -> Image.Image: + source = stage_source or source + dest = save_image(server, output, source) logger.info("saved image to %s", dest) return source diff --git a/api/onnx_web/chain/persist_s3.py b/api/onnx_web/chain/persist_s3.py index c3a889fa..bf3682bf 100644 --- a/api/onnx_web/chain/persist_s3.py +++ b/api/onnx_web/chain/persist_s3.py @@ -21,8 +21,11 @@ def persist_s3( bucket: str, endpoint_url: str = None, profile_name: str = None, + stage_source: Image.Image = None, **kwargs, ) -> Image.Image: + source = stage_source or source + session = Session(profile_name=profile_name) s3 = session.client("s3", endpoint_url=endpoint_url) diff --git a/api/onnx_web/chain/reduce_crop.py b/api/onnx_web/chain/reduce_crop.py index 43debc83..226f6cf2 100644 --- a/api/onnx_web/chain/reduce_crop.py +++ b/api/onnx_web/chain/reduce_crop.py @@ -17,8 +17,11 @@ def reduce_crop( *, origin: Size, size: Size, + stage_source: Image.Image = None, **kwargs, ) -> Image.Image: + source = stage_source or source + image = source.crop((origin.width, origin.height, size.width, size.height)) logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) return image diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index c5d143b5..0037084c 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -16,8 +16,10 @@ def reduce_thumbnail( source: Image.Image, *, size: Size, + stage_source: Image.Image, **kwargs, ) -> Image.Image: + source = stage_source or source image = source.copy() # TODO: should use a call to valid_image diff --git a/api/onnx_web/chain/source_noise.py b/api/onnx_web/chain/source_noise.py index 8135de81..76cc4f73 100644 --- a/api/onnx_web/chain/source_noise.py +++ b/api/onnx_web/chain/source_noise.py @@ -18,8 +18,10 @@ def source_noise( *, size: Size, noise_source: Callable, + stage_source: Image.Image, **kwargs, ) -> Image.Image: + source = stage_source or source logger.info("generating image from noise source") if source is not None: diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index b5e0dd1c..1fb4ee6a 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -17,7 +17,7 @@ def source_txt2img( server: ServerContext, _stage: StageParams, params: ImageParams, - source: Image.Image, + _source: Image.Image, *, size: Size, callback: ProgressCallback = None, @@ -29,7 +29,7 @@ def source_txt2img( "generating image using txt2img, %s steps: %s", params.steps, params.prompt ) - if source is not None: + if "stage_source" in kwargs: logger.warn( "a source image was passed to a txt2img stage, but will be discarded" ) diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index 14a20cf3..6d8d9587 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple import numpy as np import torch @@ -25,47 +25,47 @@ def upscale_outpaint( source: Image.Image, *, border: Border, - prompt: str = None, - mask: Image.Image = None, + stage_source: Optional[Image.Image] = None, + stage_mask: Optional[Image.Image] = None, fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, callback: ProgressCallback = None, **kwargs, ) -> Image.Image: - prompt = prompt or params.prompt + source = stage_source or source logger.info("upscaling image by expanding borders: %s", border) margin_x = float(max(border.left, border.right)) margin_y = float(max(border.top, border.bottom)) overlap = min(margin_x / source.width, margin_y / source.height) - if mask is None: + if stage_mask is None: # if no mask was provided, keep the full source image - mask = Image.new("RGB", source.size, "black") + stage_mask = Image.new("RGB", source.size, "black") - source, mask, noise, full_dims = expand_image( + source, stage_mask, noise, full_dims = expand_image( source, - mask, + stage_mask, border, fill=fill_color, noise_source=noise_source, mask_filter=mask_filter, ) - draw_mask = ImageDraw.Draw(mask) + draw_mask = ImageDraw.Draw(stage_mask) full_size = Size(*full_dims) full_latents = get_latents_from_seed(params.seed, full_size) if is_debug(): save_image(server, "last-source.png", source) - save_image(server, "last-mask.png", mask) + save_image(server, "last-mask.png", stage_mask) save_image(server, "last-noise.png", noise) def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): left, top, tile = dims size = Size(*tile_source.size) - tile_mask = mask.crop((left, top, left + tile, top + tile)) + tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) if is_debug(): save_image(server, "tile-source.png", tile_source) @@ -86,7 +86,7 @@ def upscale_outpaint( result = pipe.inpaint( tile_source, tile_mask, - prompt, + params.prompt, generator=rng, guidance_scale=params.cfg, height=size.height, @@ -99,7 +99,7 @@ def upscale_outpaint( else: rng = np.random.RandomState(params.seed) result = pipe( - prompt, + params.prompt, tile_source, generator=rng, guidance_scale=params.cfg, diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index a57a3100..a97f7ecd 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -103,8 +103,10 @@ def upscale_resrgan( source: Image.Image, *, upscale: UpscaleParams, + stage_source: Image.Image = None, **kwargs, ) -> Image.Image: + source = stage_source or source logger.info("upscaling image with Real ESRGAN: x%s", upscale.scale) output = np.array(source) diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index 851bdb86..c5356e11 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -69,11 +69,13 @@ def upscale_stable_diffusion( source: Image.Image, *, upscale: UpscaleParams, + stage_source: Image.Image = None, callback: ProgressCallback = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) upscale = upscale.with_args(**kwargs) + source = stage_source or source logger.info( "upscaling with Stable Diffusion, %s steps: %s", params.steps, params.prompt ) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index d42ff51c..9ac97942 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -188,7 +188,7 @@ def run_inpaint_pipeline( params, source, border=border, - mask=mask, + stage_mask=mask, fill_color=fill_color, mask_filter=mask_filter, noise_source=noise_source, @@ -255,7 +255,7 @@ def run_blend_pipeline( stage, params, sources=sources, - mask=mask, + stage_mask=mask, callback=progress, ) image = image.convert("RGB")