diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index fab88910..7a6873ad 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -64,7 +64,6 @@ def blend_img2img( image=source, negative_prompt=params.negative_prompt, num_inference_steps=params.steps, - strength=strength, callback=callback, **pipe_params, ) @@ -81,7 +80,6 @@ def blend_img2img( image=source, negative_prompt=params.negative_prompt, num_inference_steps=params.steps, - strength=strength, callback=callback, **pipe_params, ) diff --git a/api/onnx_web/chain/upscale_highres.py b/api/onnx_web/chain/upscale_highres.py index 53b15c73..23444a48 100644 --- a/api/onnx_web/chain/upscale_highres.py +++ b/api/onnx_web/chain/upscale_highres.py @@ -1,13 +1,11 @@ from logging import getLogger from typing import Any, Optional -import numpy as np -import torch from PIL import Image -from ..diffusers.load import load_pipeline +from ..chain.base import ChainPipeline +from ..chain.img2img import blend_img2img from ..diffusers.upscale import append_upscale_correction -from ..diffusers.utils import parse_prompt from ..params import HighresParams, ImageParams, StageParams, UpscaleParams from ..server import ServerContext from ..worker import WorkerContext @@ -30,25 +28,12 @@ def upscale_highres( callback: Optional[ProgressCallback] = None, **kwargs, ) -> Image.Image: - image = stage_source or source + source = stage_source or source if highres.scale <= 1: - return image - - # load img2img pipeline once - pipe_type = params.get_valid_pipeline("img2img") - logger.debug("using %s pipeline for highres", pipe_type) - - _prompt_pairs, loras, inversions = parse_prompt(params) - highres_pipe = pipeline or load_pipeline( - server, - params, - pipe_type, - job.get_device(), - inversions=inversions, - loras=loras, - ) + return source + chain = ChainPipeline() scaled_size = (source.width * highres.scale, source.height * highres.scale) # TODO: upscaling within the same stage prevents tiling from happening and causes OOM @@ -60,7 +45,7 @@ def upscale_highres( source = source.resize(scaled_size, resample=Image.Resampling.LANCZOS) else: logger.debug("using upscaling pipeline for highres") - upscale = append_upscale_correction( + append_upscale_correction( StageParams(), params, upscale=upscale.with_args( @@ -68,41 +53,24 @@ def upscale_highres( scale=highres.scale, outscale=highres.scale, ), - ) - source = upscale( - job, - server, - source, - callback=callback, + chain=chain, ) - if pipe_type == "lpw": - rng = torch.manual_seed(params.seed) - result = highres_pipe.img2img( - source, - params.prompt, - generator=rng, - guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, - num_images_per_prompt=1, - num_inference_steps=highres.steps, - strength=highres.strength, - eta=params.eta, - callback=callback, + chain.append( + ( + blend_img2img, + StageParams(), + { + "overlap": params.overlap, + "strength": highres.strength, + }, ) - return result.images[0] - else: - rng = np.random.RandomState(params.seed) - result = highres_pipe( - params.prompt, - source, - generator=rng, - guidance_scale=params.cfg, - negative_prompt=params.negative_prompt, - num_images_per_prompt=1, - num_inference_steps=highres.steps, - strength=highres.strength, - eta=params.eta, - callback=callback, - ) - return result.images[0] + ) + + return chain( + job, + server, + params, + source, + callback=callback, + ) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index ad6d0bc9..18d51cc3 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -63,16 +63,17 @@ def run_txt2img_pipeline( ) # apply highres - chain.append( - ( - upscale_highres, - stage, - { - "highres": highres, - "upscale": upscale, - }, + for _i in range(highres.iterations): + chain.append( + ( + upscale_highres, + stage, + { + "highres": highres, + "upscale": upscale, + }, + ) ) - ) # apply upscaling and correction, after highres append_upscale_correction( diff --git a/api/onnx_web/diffusers/upscale.py b/api/onnx_web/diffusers/upscale.py index 064b0f4b..caee1756 100644 --- a/api/onnx_web/diffusers/upscale.py +++ b/api/onnx_web/diffusers/upscale.py @@ -1,16 +1,13 @@ from logging import getLogger from typing import List, Optional, Tuple -from ..chain import ( - ChainPipeline, - PipelineStage, - correct_codeformer, - correct_gfpgan, - upscale_bsrgan, - upscale_resrgan, - upscale_stable_diffusion, - upscale_swinir, -) +from ..chain import ChainPipeline, PipelineStage +from ..chain.correct_codeformer import correct_codeformer +from ..chain.correct_gfpgan import correct_gfpgan +from ..chain.upscale_bsrgan import upscale_bsrgan +from ..chain.upscale_resrgan import upscale_resrgan +from ..chain.upscale_stable_diffusion import upscale_stable_diffusion +from ..chain.upscale_swinir import upscale_swinir from ..params import ImageParams, SizeChart, StageParams, UpscaleParams logger = getLogger(__name__) @@ -65,6 +62,9 @@ def append_upscale_correction( for stage, params in pre_stages: chain.append((stage, params)) + upscale_opts = { + "upscale": upscale, + } upscale_stage = None if upscale.scale > 1: if "bsrgan" in upscale.upscale_model: @@ -72,23 +72,23 @@ def append_upscale_correction( tile_size=stage.tile_size, outscale=upscale.outscale, ) - upscale_stage = (upscale_bsrgan, bsrgan_params, None) + upscale_stage = (upscale_bsrgan, bsrgan_params, upscale_opts) elif "esrgan" in upscale.upscale_model: esrgan_params = StageParams( tile_size=stage.tile_size, outscale=upscale.outscale, ) - upscale_stage = (upscale_resrgan, esrgan_params, None) + upscale_stage = (upscale_resrgan, esrgan_params, upscale_opts) elif "stable-diffusion" in upscale.upscale_model: mini_tile = min(SizeChart.mini, stage.tile_size) sd_params = StageParams(tile_size=mini_tile, outscale=upscale.outscale) - upscale_stage = (upscale_stable_diffusion, sd_params, None) + upscale_stage = (upscale_stable_diffusion, sd_params, upscale_opts) elif "swinir" in upscale.upscale_model: swinir_params = StageParams( tile_size=stage.tile_size, outscale=upscale.outscale, ) - upscale_stage = (upscale_swinir, swinir_params, None) + upscale_stage = (upscale_swinir, swinir_params, upscale_opts) else: logger.warn("unknown upscaling model: %s", upscale.upscale_model) diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 584ab886..10ba1d81 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -87,14 +87,14 @@ class Size: def tojson(self) -> Dict[str, int]: return { - "height": self.height, "width": self.width, + "height": self.height, } def with_args(self, **kwargs): return Size( - kwargs.get("height", self.height), kwargs.get("width", self.width), + kwargs.get("height", self.height), )