From 98fcc0752457beb9952ddbd928b593f2ebc1fc31 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Nov 2023 18:39:39 -0600 Subject: [PATCH] rewrite tile handling for image stacks --- api/onnx_web/chain/blend_linear.py | 5 +- api/onnx_web/chain/blend_mask.py | 3 +- api/onnx_web/chain/persist_disk.py | 4 +- api/onnx_web/chain/pipeline.py | 95 +++++---- api/onnx_web/chain/tile.py | 251 +++++++++++++----------- api/onnx_web/chain/upscale_bsrgan.py | 7 +- api/onnx_web/chain/upscale_simple.py | 8 +- api/onnx_web/diffusers/run.py | 20 +- api/tests/chain/test_source_noise.py | 2 +- api/tests/chain/test_source_s3.py | 2 +- api/tests/chain/test_source_url.py | 2 +- api/tests/chain/test_tile.py | 54 ++--- api/tests/chain/test_upscale_highres.py | 2 +- 13 files changed, 254 insertions(+), 201 deletions(-) diff --git a/api/onnx_web/chain/blend_linear.py b/api/onnx_web/chain/blend_linear.py index e4a98d9d..1b40a5fd 100644 --- a/api/onnx_web/chain/blend_linear.py +++ b/api/onnx_web/chain/blend_linear.py @@ -29,5 +29,8 @@ class BlendLinearStage(BaseStage): logger.info("blending source images using linear interpolation") return StageResult( - images=[Image.blend(source, stage_source, alpha) for source in sources.as_image()] + images=[ + Image.blend(source, stage_source, alpha) + for source in sources.as_image() + ] ) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 926331a3..4ebb1498 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -40,6 +40,7 @@ class BlendMaskStage(BaseStage): return StageResult( images=[ - Image.composite(stage_source, source, mult_mask) for source in sources.as_image() + Image.composite(stage_source, source, mult_mask) + for source in sources.as_image() ] ) diff --git a/api/onnx_web/chain/persist_disk.py b/api/onnx_web/chain/persist_disk.py index f7d988cc..7a2007ce 100644 --- a/api/onnx_web/chain/persist_disk.py +++ b/api/onnx_web/chain/persist_disk.py @@ -29,9 +29,7 @@ class PersistDiskStage(BaseStage): stage_source: Optional[Image.Image] = None, **kwargs, ) -> StageResult: - logger.info( - "persisting %s images to disk: %s", len(sources), output - ) + logger.info("persisting %s images to disk: %s", len(sources), output) for source, name in zip(sources, output): dest = save_image(server, name, source, params=params, size=size) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index 40a43ccf..50122d7d 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -12,8 +12,8 @@ from ..server import ServerContext from ..utils import is_debug, run_gc from ..worker import ProgressCallback, WorkerContext from .base import BaseStage -from .tile import needs_tile, process_tile_order from .result import StageResult +from .tile import needs_tile, process_tile_order logger = getLogger(__name__) @@ -163,60 +163,55 @@ class ChainPipeline: tile = min(stage_pipe.max_tile, stage_params.tile_size) # TODO: stage_sources will always be defined here - if stage_sources or must_tile: - stage_results = [] - for source in stage_sources.as_image(): - logger.info( - "image contains sources or is larger than tile size of %s, tiling stage", - tile, - ) + if must_tile: + logger.info( + "image contains sources or is larger than tile size of %s, tiling stage", + tile, + ) - def stage_tile( - source_tile: Image.Image, - tile_mask: Image.Image, - dims: Tuple[int, int, int], - ) -> StageResult: - for _i in range(worker.retries): - try: - tile_result = stage_pipe.run( - worker, - server, - stage_params, - per_stage_params, - StageResult(images=[source_tile]), - tile_mask=tile_mask, - callback=callback, - dims=dims, - **kwargs, - ) + def stage_tile( + source_tile: List[Image.Image], + tile_mask: Image.Image, + dims: Tuple[int, int, int], + ) -> List[Image.Image]: + for _i in range(worker.retries): + try: + tile_result = stage_pipe.run( + worker, + server, + stage_params, + per_stage_params, + StageResult(images=[source_tile]), + tile_mask=tile_mask, + callback=callback, + dims=dims, + **kwargs, + ) - if is_debug(): - for j, image in enumerate(tile_result.as_image()): - save_image(server, f"last-tile-{j}.png", image) + if is_debug(): + for j, image in enumerate(tile_result.as_image()): + save_image(server, f"last-tile-{j}.png", image) - # TODO: return whole result - return tile_result.as_image()[0] - except Exception: - worker.retries = worker.retries - 1 - logger.exception( - "error while running stage pipeline for tile, %s retries left", - worker.retries, - ) - server.cache.clear() - run_gc([worker.get_device()]) + return tile_result.as_image() + except Exception: + worker.retries = worker.retries - 1 + logger.exception( + "error while running stage pipeline for tile, %s retries left", + worker.retries, + ) + server.cache.clear() + run_gc([worker.get_device()]) - raise RetryException("exhausted retries on tile") + raise RetryException("exhausted retries on tile") - output = process_tile_order( - stage_params.tile_order, - source, - tile, - stage_params.outscale, - [stage_tile], - **kwargs, - ) - - stage_results.append(output) + stage_results = process_tile_order( + stage_params.tile_order, + stage_sources, + tile, + stage_params.outscale, + [stage_tile], + **kwargs, + ) stage_sources = StageResult(images=stage_results) else: diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index bc6bfdf5..aae71af1 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -2,7 +2,7 @@ import itertools from enum import Enum from logging import getLogger from math import ceil -from typing import Any, List, Optional, Protocol, Tuple +from typing import Any, Callable, List, Optional, Protocol, Tuple import numpy as np from PIL import Image @@ -17,12 +17,17 @@ from .result import StageResult logger = getLogger(__name__) +TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]] + + class TileCallback(Protocol): """ Definition for a tile job function. """ - def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult: + def __call__( + self, image: Image.Image, dims: Tuple[int, int, int] + ) -> List[Image.Image]: """ Run this stage against a single tile. """ @@ -33,6 +38,9 @@ def complete_tile( source: Image.Image, tile: int, ) -> Image.Image: + """ + TODO: clean up + """ if source is None: return source @@ -67,7 +75,7 @@ def needs_tile( return False -def get_tile_grads( +def make_tile_grads( left: int, top: int, tile: int, @@ -161,7 +169,7 @@ def blend_tiles( points = [-1, min(p1, p2), max(p1, p2), (tile * scale)] # gradient blending - grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height) + grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height) logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y) mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)] @@ -225,60 +233,18 @@ def blend_tiles( return Image.fromarray(np.uint8(pixels)) -def process_tile_grid( - source: Image.Image, - tile: int, - scale: int, - filters: List[TileCallback], - overlap: float = 0.0, - **kwargs, -) -> Image.Image: - width, height = kwargs.get("size", source.size if source else None) - - adj_tile = int(float(tile) * (1.0 - overlap)) - tiles_x = ceil(width / adj_tile) - tiles_y = ceil(height / adj_tile) - total = tiles_x * tiles_y - logger.debug( - "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap", - total, - tiles_x, - tiles_y, - adj_tile, - overlap, - ) - - tiles: List[Tuple[int, int, Image.Image]] = [] - - for y in range(tiles_y): - for x in range(tiles_x): - idx = (y * tiles_x) + x - left = x * adj_tile - top = y * adj_tile - logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x) - - tile_image = ( - source.crop((left, top, left + tile, top + tile)) if source else None - ) - tile_image = complete_tile(tile_image, tile) - - for filter in filters: - tile_image = filter(tile_image, (left, top, tile)) - - tiles.append((left, top, tile_image)) - - return blend_tiles(tiles, scale, width, height, tile, overlap) - - -def process_tile_spiral( - source: Image.Image, +def process_tile_stack( + stack: StageResult, tile: int, scale: int, filters: List[TileCallback], + tile_generator: TileGenerator, overlap: float = 0.5, **kwargs, -) -> Image.Image: - width, height = kwargs.get("size", source.size if source else None) +) -> List[Image.Image]: + sources = stack.as_image() + + width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None) mask = kwargs.get("mask", None) noise_source = kwargs.get("noise_source", noise_source_histogram) fill_color = kwargs.get("fill_color", None) @@ -286,18 +252,9 @@ def process_tile_spiral( tile_mask = None tiles: List[Tuple[int, int, Image.Image]] = [] + tile_coords = tile_generator(width, height, tile, overlap) - # tile tuples is source, multiply by scale for dest - counter = 0 - tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap) - - if len(tile_coords) == 1: - single_tile = True - else: - single_tile = False - - for left, top in tile_coords: - counter += 1 + for counter, (left, top) in enumerate(tile_coords): logger.info( "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top ) @@ -321,26 +278,31 @@ def process_tile_spiral( needs_margin = True bottom_margin = height - bottom - # if no source given, we don't have a source image - if not source: - tile_image = None - elif needs_margin: - # in the special case where the image is smaller than the specified tile size, just use the image - if single_tile: - logger.debug("creating and processing single-tile subtile") - tile_image = source - if mask: - tile_mask = mask - # otherwise use add histogram noise outside of the image border - else: - logger.debug( - "tiling and adding margins: %s, %s, %s, %s", - left_margin, - top_margin, - right_margin, - bottom_margin, - ) - base_image = source.crop( + if needs_margin: + logger.debug( + "tiling with added margins: %s, %s, %s, %s", + left_margin, + top_margin, + right_margin, + bottom_margin, + ) + tile_stack = add_margin( + stack, + left, + top, + right, + bottom, + left_margin, + top_margin, + right_margin, + bottom_margin, + tile, + noise_source, + fill_color, + ) + + if mask: + base_mask = mask.crop( ( left + left_margin, top + top_margin, @@ -348,43 +310,35 @@ def process_tile_spiral( bottom + bottom_margin, ) ) - tile_image = noise_source( - base_image, (tile, tile), (0, 0), fill=fill_color - ) - tile_image.paste(base_image, (left_margin, top_margin)) - - if mask: - base_mask = mask.crop( - ( - left + left_margin, - top + top_margin, - right + right_margin, - bottom + bottom_margin, - ) - ) - tile_mask = Image.new("L", (tile, tile), color=0) - tile_mask.paste(base_mask, (left_margin, top_margin)) + tile_mask = Image.new("L", (tile, tile), color=0) + tile_mask.paste(base_mask, (left_margin, top_margin)) else: logger.debug("tiling normally") - tile_image = source.crop((left, top, right, bottom)) + tile_stack = get_result_tile(stack, (left, top), Size(tile, tile)) if mask: tile_mask = mask.crop((left, top, right, bottom)) for image_filter in filters: - tile_image = image_filter(tile_image, tile_mask, (left, top, tile)) + tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile)) - tiles.append((left, top, tile_image)) + tiles.append((left, top, tile_stack)) - if single_tile: - return tile_image - else: - return blend_tiles(tiles, scale, width, height, tile, overlap) + lefts, tops, stacks = list(zip(*tiles)) + coords = list(zip(lefts, tops)) + stacks = list(zip(*stacks)) + + result = [] + for stack in stacks: + stack_tiles = zip(coords, stack) + result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap)) + + return result def process_tile_order( order: TileOrder, - source: Image.Image, + stack: StageResult, tile: int, scale: int, filters: List[TileCallback], @@ -395,13 +349,17 @@ def process_tile_order( """ if order == TileOrder.grid: logger.debug("using grid tile order with tile size: %s", tile) - return process_tile_grid(source, tile, scale, filters, **kwargs) + return process_tile_stack( + stack, tile, scale, filters, generate_tile_grid, **kwargs + ) elif order == TileOrder.kernel: logger.debug("using kernel tile order with tile size: %s", tile) raise NotImplementedError() elif order == TileOrder.spiral: logger.debug("using spiral tile order with tile size: %s", tile) - return process_tile_spiral(source, tile, scale, filters, **kwargs) + return process_tile_stack( + stack, tile, scale, filters, generate_tile_spiral, **kwargs + ) else: logger.warning("unknown tile order: %s", order) raise ValueError() @@ -495,3 +453,76 @@ def generate_tile_spiral( height_tile_target -= abs(state.value[1]) return tile_coords + + +def generate_tile_grid( + width: int, + height: int, + tile: int, + overlap: float = 0.0, +) -> List[Tuple[int, int]]: + adj_tile = int(float(tile) * (1.0 - overlap)) + tiles_x = ceil(width / adj_tile) + tiles_y = ceil(height / adj_tile) + total = tiles_x * tiles_y + logger.debug( + "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap", + total, + tiles_x, + tiles_y, + adj_tile, + overlap, + ) + + tiles: List[Tuple[int, int, Image.Image]] = [] + + for y in range(tiles_y): + for x in range(tiles_x): + left = x * adj_tile + top = y * adj_tile + + tiles.append((int(left), int(top))) + + return tiles + + +def get_result_tile( + result: StageResult, + origin: Tuple[int, int], + tile: Size, +) -> List[Image.Image]: + top, left = origin + return [ + layer.crop((top, left, top + tile.height, left + tile.width)) + for layer in result.as_image() + ] + + +def add_margin( + stack: List[Image.Image], + left: int, + top: int, + right: int, + bottom: int, + left_margin: int, + top_margin: int, + right_margin: int, + bottom_margin: int, + tile: int, + noise_source, + fill_color, +) -> List[Image.Image]: + results = [] + for source in stack: + base_image = source.crop( + ( + left + left_margin, + top + top_margin, + right + right_margin, + bottom + bottom_margin, + ) + ) + tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color) + tile_image.paste(base_image, (left_margin, top_margin)) + + return results diff --git a/api/onnx_web/chain/upscale_bsrgan.py b/api/onnx_web/chain/upscale_bsrgan.py index d68c0042..6ade9580 100644 --- a/api/onnx_web/chain/upscale_bsrgan.py +++ b/api/onnx_web/chain/upscale_bsrgan.py @@ -79,12 +79,15 @@ class UpscaleBSRGANStage(BaseStage): logger.trace("BSRGAN input shape: %s", image.shape) scale = upscale.outscale - logger.trace("BSRGAN output shape: %s", ( + logger.trace( + "BSRGAN output shape: %s", + ( image.shape[0], image.shape[1], image.shape[2] * scale, image.shape[3] * scale, - )) + ), + ) output = bsrgan(image) diff --git a/api/onnx_web/chain/upscale_simple.py b/api/onnx_web/chain/upscale_simple.py index 0ec0499c..7e939bd4 100644 --- a/api/onnx_web/chain/upscale_simple.py +++ b/api/onnx_web/chain/upscale_simple.py @@ -38,10 +38,14 @@ class UpscaleSimpleStage(BaseStage): if method == "bilinear": logger.debug("using bilinear interpolation for highres") - outputs.append(source.resize(scaled_size, resample=Image.Resampling.BILINEAR)) + outputs.append( + source.resize(scaled_size, resample=Image.Resampling.BILINEAR) + ) elif method == "lanczos": logger.debug("using Lanczos interpolation for highres") - outputs.append(source.resize(scaled_size, resample=Image.Resampling.LANCZOS)) + outputs.append( + source.resize(scaled_size, resample=Image.Resampling.LANCZOS) + ) else: logger.warning("unknown upscaling method: %s", method) diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 317c73c5..e80734a0 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -105,7 +105,9 @@ def run_txt2img_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback() - images = chain.run(worker, server, params, StageResult.empty(), callback=progress, latents=latents) + images = chain.run( + worker, server, params, StageResult.empty(), callback=progress, latents=latents + ) _pairs, loras, inversions, _rest = parse_prompt(params) @@ -200,7 +202,9 @@ def run_img2img_pipeline( # run and append the filtered source progress = worker.get_progress_callback() - images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress) + images = chain.run( + worker, server, params, StageResult(images=[source]), callback=progress + ) if source_filter is not None and source_filter != "none": images.append(source) @@ -380,7 +384,9 @@ def run_inpaint_pipeline( # run and save latents = get_latents_from_seed(params.seed, size, batch=params.batch) progress = worker.get_progress_callback() - images = chain.run(worker, server, params, [source], callback=progress, latents=latents) + images = chain.run( + worker, server, params, [source], callback=progress, latents=latents + ) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): @@ -455,7 +461,9 @@ def run_upscale_pipeline( # run and save progress = worker.get_progress_callback() - images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress) + images = chain.run( + worker, server, params, StageResult(images=[source]), callback=progress + ) _pairs, loras, inversions, _rest = parse_prompt(params) for image, output in zip(images, outputs): @@ -505,7 +513,9 @@ def run_blend_pipeline( # run and save progress = worker.get_progress_callback() - images = chain.run(worker, server, params, StageResult(images=sources), callback=progress) + images = chain.run( + worker, server, params, StageResult(images=sources), callback=progress + ) for image, output in zip(images, outputs): dest = save_image(server, output, image, params, size, upscale=upscale) diff --git a/api/tests/chain/test_source_noise.py b/api/tests/chain/test_source_noise.py index f43a8f86..37c99bfa 100644 --- a/api/tests/chain/test_source_noise.py +++ b/api/tests/chain/test_source_noise.py @@ -1,6 +1,6 @@ import unittest -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import StageResult from onnx_web.chain.source_noise import SourceNoiseStage from onnx_web.image.noise_source import noise_source_fill_edge from onnx_web.params import HighresParams, Size, UpscaleParams diff --git a/api/tests/chain/test_source_s3.py b/api/tests/chain/test_source_s3.py index 9b1e11ea..59bbb72f 100644 --- a/api/tests/chain/test_source_s3.py +++ b/api/tests/chain/test_source_s3.py @@ -1,6 +1,6 @@ import unittest -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import StageResult from onnx_web.chain.source_s3 import SourceS3Stage from onnx_web.params import HighresParams, Size, UpscaleParams diff --git a/api/tests/chain/test_source_url.py b/api/tests/chain/test_source_url.py index fe7588c7..4d03dedb 100644 --- a/api/tests/chain/test_source_url.py +++ b/api/tests/chain/test_source_url.py @@ -1,6 +1,6 @@ import unittest -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import StageResult from onnx_web.chain.source_url import SourceURLStage from onnx_web.params import HighresParams, Size, UpscaleParams diff --git a/api/tests/chain/test_tile.py b/api/tests/chain/test_tile.py index 7f599db2..6323c0bb 100644 --- a/api/tests/chain/test_tile.py +++ b/api/tests/chain/test_tile.py @@ -4,11 +4,11 @@ from PIL import Image from onnx_web.chain.tile import ( complete_tile, + generate_tile_grid, generate_tile_spiral, - get_tile_grads, + make_tile_grads, needs_tile, - process_tile_grid, - process_tile_spiral, + process_tile_stack, ) from onnx_web.params import Size @@ -59,24 +59,46 @@ class TestNeedsTile(unittest.TestCase): class TestTileGrads(unittest.TestCase): def test_center_tile(self): - grad_x, grad_y = get_tile_grads(32, 32, 8, 64, 64) + grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64) self.assertEqual(grad_x, [0, 1, 1, 0]) self.assertEqual(grad_y, [0, 1, 1, 0]) def test_vertical_edge_tile(self): - grad_x, grad_y = get_tile_grads(32, 0, 8, 64, 8) + grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8) self.assertEqual(grad_x, [0, 1, 1, 0]) self.assertEqual(grad_y, [1, 1, 1, 1]) def test_horizontal_edge_tile(self): - grad_x, grad_y = get_tile_grads(0, 32, 8, 8, 64) + grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64) self.assertEqual(grad_x, [1, 1, 1, 1]) self.assertEqual(grad_y, [0, 1, 1, 0]) +class TestGenerateTileGrid(unittest.TestCase): + def test_grid_complete(self): + tiles = generate_tile_grid(16, 16, 8, 0.0) + + self.assertEqual(len(tiles), 4) + self.assertEqual(tiles, [(0, 0), (8, 0), (8, 8), (0, 8)]) + + def test_grid_no_overlap(self): + tiles = generate_tile_grid(64, 64, 8, 0.0) + + self.assertEqual(len(tiles), 64) + self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)]) + self.assertEqual(tiles[-5:-1], [(16, 24), (24, 24), (32, 24), (32, 32)]) + + def test_grid_50_overlap(self): + tiles = generate_tile_grid(64, 64, 8, 0.5) + + self.assertEqual(len(tiles), 225) + self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)]) + self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) + + class TestGenerateTileSpiral(unittest.TestCase): def test_spiral_complete(self): tiles = generate_tile_spiral(16, 16, 8, 0.0) @@ -99,29 +121,15 @@ class TestGenerateTileSpiral(unittest.TestCase): self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)]) -class TestProcessTileGrid(unittest.TestCase): +class TestProcessTileStack(unittest.TestCase): def test_grid_full(self): source = Image.new("RGB", (64, 64)) - blend = process_tile_grid(source, 32, 1, []) + blend = process_tile_stack(source, 32, 1, []) self.assertEqual(blend.size, (64, 64)) def test_grid_partial(self): source = Image.new("RGB", (72, 72)) - blend = process_tile_grid(source, 32, 1, []) - - self.assertEqual(blend.size, (72, 72)) - - -class TestProcessTileSpiral(unittest.TestCase): - def test_grid_full(self): - source = Image.new("RGB", (64, 64)) - blend = process_tile_spiral(source, 32, 1, []) - - self.assertEqual(blend.size, (64, 64)) - - def test_grid_partial(self): - source = Image.new("RGB", (72, 72)) - blend = process_tile_spiral(source, 32, 1, []) + blend = process_tile_stack(source, 32, 1, []) self.assertEqual(blend.size, (72, 72)) diff --git a/api/tests/chain/test_upscale_highres.py b/api/tests/chain/test_upscale_highres.py index 8789e447..72437fc8 100644 --- a/api/tests/chain/test_upscale_highres.py +++ b/api/tests/chain/test_upscale_highres.py @@ -1,6 +1,6 @@ import unittest -from onnx_web.chain.result import StageResult +from onnx_web.chain.result import StageResult from onnx_web.chain.upscale_highres import UpscaleHighresStage from onnx_web.params import HighresParams, UpscaleParams