diff --git a/api/onnx_web/chain/highres.py b/api/onnx_web/chain/highres.py index 523b9a3a..2a43e051 100644 --- a/api/onnx_web/chain/highres.py +++ b/api/onnx_web/chain/highres.py @@ -57,7 +57,7 @@ def stage_highres( chain.stage( BlendImg2ImgStage(), - stage, + stage.with_args(outscale=1), overlap=params.vae_overlap, prompt_index=prompt_index + i, strength=highres.strength, diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index f600d06b..b88d9628 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -156,7 +156,6 @@ def blend_tiles( value = np.zeros(scaled_size) for left, top, tile_image in tiles: - # TODO: histogram equalization equalized = np.array(tile_image).astype(np.float32) mask = np.ones_like(equalized[:, :, 0]) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 5991c57a..f5bd3a06 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -64,21 +64,20 @@ def run_txt2img_pipeline( # apply upscaling and correction, before highres highres_size = params.unet_tile - stage = StageParams(tile_size=highres_size) - if params.is_panorama(): + if server.has_feature("panorama-highres"): + # run the whole highres pass with one panorama call + highres_size = tile_size * highres.scale + chain.stage( BlendDenoiseStage(), - stage, + StageParams(tile_size=highres_size), ) - if server.has_feature("panorama-highres"): - highres_size = tile_size * highres.scale - first_upscale, after_upscale = split_upscale(upscale) if first_upscale: stage_upscale_correction( - stage, + StageParams(outscale=first_upscale.outscale, tile_size=highres_size), params, chain=chain, upscale=first_upscale, @@ -86,7 +85,7 @@ def run_txt2img_pipeline( # apply highres stage_highres( - stage, + StageParams(outscale=highres.scale, tile_size=highres_size), params, highres, upscale, @@ -96,7 +95,7 @@ def run_txt2img_pipeline( # apply upscaling and correction, after highres stage_upscale_correction( - stage, + StageParams(outscale=after_upscale.outscale, tile_size=highres_size), params, chain=chain, upscale=after_upscale, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index efac3742..7d0ad48c 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -369,6 +369,17 @@ class StageParams: self.tile_order = tile_order self.tile_size = tile_size + def with_args( + self, + **kwargs, + ): + return StageParams( + name=kwargs.get("name", self.name), + outscale=kwargs.get("outscale", self.outscale), + tile_order=kwargs.get("tile_order", self.tile_order), + tile_size=kwargs.get("tile_size", self.tile_size), + ) + class UpscaleParams: def __init__( diff --git a/api/tests/test_diffusers/test_run.py b/api/tests/test_diffusers/test_run.py index e9988116..e4004e15 100644 --- a/api/tests/test_diffusers/test_run.py +++ b/api/tests/test_diffusers/test_run.py @@ -149,10 +149,11 @@ class TestTxt2ImgPipeline(unittest.TestCase): 3.0, 1, 1, + unet_tile=256, ), Size(256, 256), ["test-txt2img-highres.png"], - UpscaleParams("test"), + UpscaleParams("test", scale=2, outscale=2), HighresParams(True, 2, 0, 0), )