diff --git a/api/onnx_web/chain/__init__.py b/api/onnx_web/chain/__init__.py index edda66d2..e0e23a30 100644 --- a/api/onnx_web/chain/__init__.py +++ b/api/onnx_web/chain/__init__.py @@ -1,6 +1,5 @@ from .base import ChainPipeline, PipelineStage, StageParams from .blend_img2img import BlendImg2ImgStage -from .blend_inpaint import BlendInpaintStage from .blend_linear import BlendLinearStage from .blend_mask import BlendMaskStage from .correct_codeformer import CorrectCodeformerStage @@ -23,7 +22,7 @@ from .upscale_swinir import UpscaleSwinIRStage CHAIN_STAGES = { "blend-img2img": BlendImg2ImgStage, - "blend-inpaint": BlendInpaintStage, + "blend-inpaint": UpscaleOutpaintStage, "blend-linear": BlendLinearStage, "blend-mask": BlendMaskStage, "correct-codeformer": CorrectCodeformerStage, diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 566d938d..3467e549 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -146,7 +146,7 @@ class ChainPipeline: ) def stage_tile( - source_tile: Image.Image, tile_mask: Image.Image, _dims + source_tile: Image.Image, tile_mask: Image.Image, dims: Tuple[int, int, int] ) -> Image.Image: output_tile = stage_pipe.run( job, @@ -156,6 +156,7 @@ class ChainPipeline: [source_tile], tile_mask=tile_mask, callback=callback, + dims=dims, **kwargs, )[0] diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py deleted file mode 100644 index 8c15d4cf..00000000 --- a/api/onnx_web/chain/blend_inpaint.py +++ /dev/null @@ -1,140 +0,0 @@ -from logging import getLogger -from typing import Callable, List, Optional, Tuple - -import numpy as np -import torch -from PIL import Image - -from ..diffusers.load import load_pipeline -from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt -from ..image import expand_image, mask_filter_none, noise_source_histogram -from ..output import save_image -from ..params import Border, ImageParams, Size, SizeChart, StageParams -from ..server import ServerContext -from ..utils import is_debug -from ..worker import ProgressCallback, WorkerContext -from .stage import BaseStage -from .tile import process_tile_order - -logger = getLogger(__name__) - - -class BlendInpaintStage(BaseStage): - def run( - self, - job: WorkerContext, - server: ServerContext, - stage: StageParams, - params: ImageParams, - sources: List[Image.Image], - *, - expand: Border, - stage_source: Optional[Image.Image] = None, - stage_mask: Optional[Image.Image] = None, - fill_color: str = "white", - mask_filter: Callable = mask_filter_none, - noise_source: Callable = noise_source_histogram, - callback: Optional[ProgressCallback] = None, - **kwargs, - ) -> List[Image.Image]: - params = params.with_args(**kwargs) - expand = expand.with_args(**kwargs) - logger.info( - "blending image using inpaint, %s steps: %s", params.steps, params.prompt - ) - - prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( - params - ) - pipe_type = params.get_valid_pipeline("inpaint") - pipe = load_pipeline( - server, - params, - pipe_type, - job.get_device(), - inversions=inversions, - loras=loras, - ) - - outputs = [] - for source in sources: - if stage_mask is None: - # if no mask was provided, keep the full source image - stage_mask = Image.new("RGB", source.size, "black") - - source, stage_mask, noise, _full_dims = expand_image( - source, - stage_mask, - expand, - fill=fill_color, - noise_source=noise_source, - mask_filter=mask_filter, - ) - - if is_debug(): - save_image(server, "last-source.png", source) - save_image(server, "last-mask.png", stage_mask) - save_image(server, "last-noise.png", noise) - - def outpaint(tile_source: Image.Image, dims: Tuple[int, int, int]): - left, top, tile = dims - size = Size(*tile_source.size) - tile_mask = stage_mask.crop((left, top, left + tile, top + tile)) - - if is_debug(): - save_image(server, "tile-source.png", tile_source) - save_image(server, "tile-mask.png", tile_mask) - - latents = get_latents_from_seed(params.seed, size) - if params.lpw(): - logger.debug("using LPW pipeline for inpaint") - rng = torch.manual_seed(params.seed) - result = pipe.inpaint( - prompt, - generator=rng, - guidance_scale=params.cfg, - height=size.height, - image=tile_source, - latents=latents, - mask_image=tile_mask, - negative_prompt=negative_prompt, - num_inference_steps=params.steps, - width=size.width, - eta=params.eta, - callback=callback, - ) - else: - # encode and record alternative prompts outside of LPW - prompt_embeds = encode_prompt( - pipe, prompt_pairs, params.batch, params.do_cfg() - ) - pipe.unet.set_prompts(prompt_embeds) - - rng = np.random.RandomState(params.seed) - result = pipe( - prompt, - generator=rng, - guidance_scale=params.cfg, - height=size.height, - image=tile_source, - latents=latents, - mask_image=stage_mask, - negative_prompt=negative_prompt, - num_inference_steps=params.steps, - width=size.width, - eta=params.eta, - callback=callback, - ) - - return result.images[0] - - outputs.append( - process_tile_order( - stage.tile_order, - source, - SizeChart.auto, - 1, - [outpaint], - overlap=params.overlap, - ) - ) diff --git a/api/onnx_web/chain/source_txt2img.py b/api/onnx_web/chain/source_txt2img.py index 18058ad9..2442060a 100644 --- a/api/onnx_web/chain/source_txt2img.py +++ b/api/onnx_web/chain/source_txt2img.py @@ -1,12 +1,12 @@ from logging import getLogger -from typing import Optional +from typing import Optional, Tuple import numpy as np import torch from PIL import Image from ..diffusers.load import load_pipeline -from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt +from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext @@ -26,8 +26,10 @@ class SourceTxt2ImgStage(BaseStage): params: ImageParams, _source: Image.Image, *, + dims: Tuple[int, int, int], size: Size, callback: Optional[ProgressCallback] = None, + latents: Optional[np.ndarray] = None, **kwargs, ) -> Image.Image: params = params.with_args(**kwargs) @@ -47,15 +49,23 @@ class SourceTxt2ImgStage(BaseStage): ) tile_size = params.tiles - if max(size) > tile_size: - latent_size = Size(tile_size, tile_size) + + # generate new latents or slice existing + if latents is None: + if max(size) > tile_size: + latent_size = Size(tile_size, tile_size) + pipe_width = pipe_height = tile_size + else: + latent_size = Size(size.width, size.height) + pipe_width = size.width + pipe_height = size.height + + # generate new latents latents = get_latents_from_seed(params.seed, latent_size, params.batch) - pipe_width = pipe_height = tile_size else: - latent_size = Size(size.width, size.height) - latents = get_latents_from_seed(params.seed, latent_size, params.batch) - pipe_width = size.width - pipe_height = size.height + # slice existing latents + latents = get_tile_latents(latents, dims, size) + pipe_width, pipe_height, _tile_size = dims pipe_type = params.get_valid_pipeline("txt2img") pipe = load_pipeline( diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index b41ec82b..0a37d878 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -6,7 +6,7 @@ import torch from PIL import Image from ..diffusers.load import load_pipeline -from ..diffusers.utils import encode_prompt, get_latents_from_seed, parse_prompt +from ..diffusers.utils import encode_prompt, get_latents_from_seed, get_tile_latents, parse_prompt from ..image import mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams @@ -28,15 +28,17 @@ class UpscaleOutpaintStage(BaseStage): stage: StageParams, params: ImageParams, sources: List[Image.Image], - tile_mask: Image.Image, *, border: Border, - stage_source: Optional[Image.Image] = None, - stage_mask: Optional[Image.Image] = None, + dims: Tuple[int, int, int], + tile_mask: Image.Image, fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, + latents: Optional[np.ndarray] = None, callback: Optional[ProgressCallback] = None, + stage_source: Optional[Image.Image] = None, + stage_mask: Optional[Image.Image] = None, **kwargs, ) -> List[Image.Image]: prompt_pairs, loras, inversions, (prompt, negative_prompt) = parse_prompt( @@ -64,18 +66,25 @@ class UpscaleOutpaintStage(BaseStage): outputs.append(source) continue - source_width, source_height = source.size - source_size = Size(source_width, source_height) + size = Size(*source.size) tile_size = params.tiles - if max(source_size) > tile_size: - latent_size = Size(tile_size, tile_size) - latents = get_latents_from_seed(params.seed, latent_size) - pipe_width = pipe_height = tile_size + + # generate new latents or slice existing + if latents is None: + if max(size) > tile_size: + latent_size = Size(tile_size, tile_size) + pipe_width = pipe_height = tile_size + else: + latent_size = Size(size.width, size.height) + pipe_width = size.width + pipe_height = size.height + + # generate new latents + latents = get_latents_from_seed(params.seed, latent_size, params.batch) else: - latent_size = Size(source_size.width, source_size.height) - latents = get_latents_from_seed(params.seed, latent_size) - pipe_width = source_size.width - pipe_height = source_size.height + # slice existing latents + latents = get_tile_latents(latents, dims, size) + pipe_width, pipe_height, _tile_size = dims if params.lpw(): logger.debug("using LPW pipeline for inpaint") diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 2710df9e..8cd780cd 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -27,7 +27,7 @@ from ..server import ServerContext from ..server.load import get_source_filters from ..utils import is_debug, run_gc, show_system_toast from ..worker import WorkerContext -from .utils import parse_prompt +from .utils import get_latents_from_seed, parse_prompt logger = getLogger(__name__) @@ -81,8 +81,9 @@ def run_txt2img_pipeline( ) # run and save + latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = job.get_progress_callback() - images = chain(job, server, params, [], callback=progress) + images = chain.run(job, server, params, [], callback=progress, latents=latents) _pairs, loras, inversions, _rest = parse_prompt(params) @@ -287,8 +288,9 @@ def run_inpaint_pipeline( ) # run and save + latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = job.get_progress_callback() - images = chain(job, server, params, [source], callback=progress) + images = chain(job, server, params, [source], callback=progress, latents=latents) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 7e7100c8..75b2fc10 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -17,7 +17,7 @@ from ..diffusers.run import ( ) from ..diffusers.utils import replace_wildcards from ..output import json_params, make_output_name -from ..params import Border, StageParams, TileOrder, UpscaleParams +from ..params import Border, Size, StageParams, TileOrder, UpscaleParams from ..transformers.run import run_txt2txt_pipeline from ..utils import ( base_join, @@ -163,8 +163,9 @@ def img2img(server: ServerContext, pool: DevicePoolExecutor): return error_reply("source image is required") source = Image.open(BytesIO(source_file.read())).convert("RGB") + size = Size(source.width, source.height) - device, params, size = pipeline_from_request(server, "img2img") + device, params, _size = pipeline_from_request(server, "img2img") upscale = upscale_from_request() highres = highres_from_request() source_filter = get_from_list( @@ -249,12 +250,14 @@ def inpaint(server: ServerContext, pool: DevicePoolExecutor): return error_reply("mask image is required") source = Image.open(BytesIO(source_file.read())).convert("RGB") + size = Size(source.width, source.height) + mask_top_layer = Image.open(BytesIO(mask_file.read())).convert("RGBA") mask = Image.new("RGBA", mask_top_layer.size, color=(0, 0, 0, 255)) mask.alpha_composite(mask_top_layer) mask.convert(mode="L") - device, params, size = pipeline_from_request(server, "inpaint") + device, params, _size = pipeline_from_request(server, "inpaint") expand = border_from_request() upscale = upscale_from_request() highres = highres_from_request()