From 3a290822eb3a5038753d1643a93c0f6acc247387 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 11 Feb 2023 18:00:18 -0600 Subject: [PATCH] feat(api): pass tile order to inpaint and outpaint pipelines --- api/onnx_web/chain/base.py | 10 +++++++--- api/onnx_web/chain/blend_inpaint.py | 6 ++++-- api/onnx_web/chain/upscale_outpaint.py | 13 +++++++++---- api/onnx_web/chain/utils.py | 23 +++++++++++++++++++++++ api/onnx_web/diffusion/run.py | 3 ++- api/onnx_web/params.py | 12 ++++++++++-- api/onnx_web/serve.py | 5 ++++- api/params.json | 8 ++++---- gui/src/components/tab/Inpaint.tsx | 4 ++-- 9 files changed, 65 insertions(+), 19 deletions(-) diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 8e7a6487..6aa36071 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -9,7 +9,7 @@ from ..device_pool import JobContext from ..output import save_image from ..params import ImageParams, StageParams from ..utils import ServerContext, is_debug -from .utils import process_tile_grid +from .utils import process_tile_order logger = getLogger(__name__) @@ -100,8 +100,12 @@ class ChainPipeline: return tile - image = process_tile_grid( - image, stage_params.tile_size, stage_params.outscale, [stage_tile] + image = process_tile_order( + stage_params.tile_order, + image, + stage_params.tile_size, + stage_params.outscale, + [stage_tile], ) else: logger.info("image within tile size, running stage") diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index a980e5a5..350bdfe0 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -12,7 +12,7 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..utils import ServerContext, is_debug -from .utils import process_tile_grid +from .utils import process_tile_order logger = getLogger(__name__) @@ -101,7 +101,9 @@ def blend_inpaint( return result.images[0] - output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) + output = process_tile_order( + stage.tile_order, source_image, SizeChart.auto, 1, [outpaint] + ) logger.info("final output image size", output.size) return output diff --git a/api/onnx_web/chain/upscale_outpaint.py b/api/onnx_web/chain/upscale_outpaint.py index ffd55b02..2f64f62c 100644 --- a/api/onnx_web/chain/upscale_outpaint.py +++ b/api/onnx_web/chain/upscale_outpaint.py @@ -10,9 +10,9 @@ from ..device_pool import JobContext from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline from ..image import expand_image, mask_filter_none, noise_source_histogram from ..output import save_image -from ..params import Border, ImageParams, Size, SizeChart, StageParams +from ..params import Border, ImageParams, Size, SizeChart, StageParams, TileOrder from ..utils import ServerContext, is_debug -from .utils import process_tile_grid, process_tile_spiral +from .utils import process_tile_grid, process_tile_order logger = getLogger(__name__) @@ -120,8 +120,13 @@ def upscale_outpaint( "outpainting with an even border, using spiral tiling with %s overlap", overlap, ) - output = process_tile_spiral( - source_image, SizeChart.auto, 1, [outpaint], overlap=overlap + output = process_tile_order( + stage.tile_order, + source_image, + SizeChart.auto, + 1, + [outpaint], + overlap=overlap, ) else: logger.debug("outpainting with an uneven border, using grid tiling") diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index 1bb3ac22..053a093f 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -3,6 +3,8 @@ from typing import List, Protocol, Tuple from PIL import Image +from ..params import TileOrder + logger = getLogger(__name__) @@ -16,6 +18,7 @@ def process_tile_grid( tile: int, scale: int, filters: List[TileCallback], + **kwargs, ) -> Image.Image: width, height = source.size image = Image.new("RGB", (width * scale, height * scale)) @@ -46,6 +49,7 @@ def process_tile_spiral( scale: int, filters: List[TileCallback], overlap: float = 0.5, + **kwargs, ) -> Image.Image: if scale != 1: raise Exception("unsupported scale") @@ -87,3 +91,22 @@ def process_tile_spiral( image.paste(tile_image, (left * scale, top * scale)) return image + + +def process_tile_order( + order: TileOrder, + source: Image.Image, + tile: int, + scale: int, + filters: List[TileCallback], + **kwargs, +) -> Image.Image: + if order == TileOrder.grid: + logger.debug("using grid tile order with tile size: %s", tile) + return process_tile_grid(source, tile, scale, filters, **kwargs) + elif order == TileOrder.kernel: + logger.debug("using kernel tile order with tile size: %s", tile) + raise NotImplementedError() + elif order == TileOrder.spiral: + logger.debug("using spiral tile order with tile size: %s", tile) + return process_tile_spiral(source, tile, scale, filters, **kwargs) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index e8924b94..1c96bfd7 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -151,10 +151,11 @@ def run_inpaint_pipeline( mask_filter: Any, strength: float, fill_color: str, + tile_order: str, ) -> None: # device = job.get_device() # progress = job.get_progress_callback() - stage = StageParams() + stage = StageParams(tile_order=tile_order) image = upscale_outpaint( job, diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index f81e61c9..a6144c24 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -14,6 +14,12 @@ class SizeChart(IntEnum): hd64k = 2**16 +class TileOrder: + grid = "grid" + kernel = "kernel" + spiral = "spiral" + + Param = Union[str, int, float] Point = Tuple[int, int] @@ -122,13 +128,15 @@ class StageParams: def __init__( self, name: Optional[str] = None, - tile_size: int = SizeChart.auto, outscale: int = 1, + tile_order: str = TileOrder.grid, + tile_size: int = SizeChart.auto, # batch_size: int = 1, ) -> None: self.name = name - self.tile_size = tile_size self.outscale = outscale + self.tile_order = tile_order + self.tile_size = tile_size class UpscaleParams: diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 67e049e6..79fdc57b 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -64,7 +64,7 @@ from .image import ( # mask filters; noise sources noise_source_uniform, ) from .output import json_params, make_output_name -from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams +from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams, TileOrder from .utils import ( ServerContext, base_join, @@ -589,6 +589,7 @@ def inpaint(): get_config_value("strength", "max"), get_config_value("strength", "min"), ) + tile_order = get_from_list(request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral]) output = make_output_name( context, @@ -604,6 +605,7 @@ def inpaint(): noise_source.__name__, strength, fill_color, + tile_order, ), ) logger.info("inpaint job queued for: %s", output) @@ -625,6 +627,7 @@ def inpaint(): mask_filter, strength, fill_color, + tile_order, needs_device=device, ) diff --git a/api/params.json b/api/params.json index 6f194a0e..f1a9824d 100644 --- a/api/params.json +++ b/api/params.json @@ -66,10 +66,6 @@ "default": "histogram", "keys": [] }, - "order": { - "default": "spiral", - "keys": [] - }, "outscale": { "default": 1, "min": 1, @@ -118,6 +114,10 @@ "max": 1, "step": 0.01 }, + "tileOrder": { + "default": "spiral", + "keys": [] + }, "top": { "default": 0, "min": 0, diff --git a/gui/src/components/tab/Inpaint.tsx b/gui/src/components/tab/Inpaint.tsx index bfd9983f..2098857d 100644 --- a/gui/src/components/tab/Inpaint.tsx +++ b/gui/src/components/tab/Inpaint.tsx @@ -1,5 +1,5 @@ import { doesExist, mustExist } from '@apextoaster/js-utils'; -import { Box, Button, FormControl, FormControlLabel, InputLabel, Select, Stack } from '@mui/material'; +import { Box, Button, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material'; import * as React from 'react'; import { useContext } from 'react'; import { useMutation, useQuery, useQueryClient } from 'react-query'; @@ -161,7 +161,7 @@ export function Inpaint() { tileOrder: e.target.value, }); }} - > + >{['grid', 'kernel', 'spiral'].map((name) => {name})}