1
0
Fork 0

feat(api): pass tile order to inpaint and outpaint pipelines

This commit is contained in:
Sean Sube 2023-02-11 18:00:18 -06:00
parent 51651abd08
commit 3a290822eb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
9 changed files with 65 additions and 19 deletions

View File

@ -9,7 +9,7 @@ from ..device_pool import JobContext
from ..output import save_image from ..output import save_image
from ..params import ImageParams, StageParams from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
from .utils import process_tile_grid from .utils import process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)
@ -100,8 +100,12 @@ class ChainPipeline:
return tile return tile
image = process_tile_grid( image = process_tile_order(
image, stage_params.tile_size, stage_params.outscale, [stage_tile] stage_params.tile_order,
image,
stage_params.tile_size,
stage_params.outscale,
[stage_tile],
) )
else: else:
logger.info("image within tile size, running stage") logger.info("image within tile size, running stage")

View File

@ -12,7 +12,7 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..params import Border, ImageParams, Size, SizeChart, StageParams
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
from .utils import process_tile_grid from .utils import process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)
@ -101,7 +101,9 @@ def blend_inpaint(
return result.images[0] return result.images[0]
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint]) output = process_tile_order(
stage.tile_order, source_image, SizeChart.auto, 1, [outpaint]
)
logger.info("final output image size", output.size) logger.info("final output image size", output.size)
return output return output

View File

@ -10,9 +10,9 @@ from ..device_pool import JobContext
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
from ..image import expand_image, mask_filter_none, noise_source_histogram from ..image import expand_image, mask_filter_none, noise_source_histogram
from ..output import save_image from ..output import save_image
from ..params import Border, ImageParams, Size, SizeChart, StageParams from ..params import Border, ImageParams, Size, SizeChart, StageParams, TileOrder
from ..utils import ServerContext, is_debug from ..utils import ServerContext, is_debug
from .utils import process_tile_grid, process_tile_spiral from .utils import process_tile_grid, process_tile_order
logger = getLogger(__name__) logger = getLogger(__name__)
@ -120,8 +120,13 @@ def upscale_outpaint(
"outpainting with an even border, using spiral tiling with %s overlap", "outpainting with an even border, using spiral tiling with %s overlap",
overlap, overlap,
) )
output = process_tile_spiral( output = process_tile_order(
source_image, SizeChart.auto, 1, [outpaint], overlap=overlap stage.tile_order,
source_image,
SizeChart.auto,
1,
[outpaint],
overlap=overlap,
) )
else: else:
logger.debug("outpainting with an uneven border, using grid tiling") logger.debug("outpainting with an uneven border, using grid tiling")

View File

@ -3,6 +3,8 @@ from typing import List, Protocol, Tuple
from PIL import Image from PIL import Image
from ..params import TileOrder
logger = getLogger(__name__) logger = getLogger(__name__)
@ -16,6 +18,7 @@ def process_tile_grid(
tile: int, tile: int,
scale: int, scale: int,
filters: List[TileCallback], filters: List[TileCallback],
**kwargs,
) -> Image.Image: ) -> Image.Image:
width, height = source.size width, height = source.size
image = Image.new("RGB", (width * scale, height * scale)) image = Image.new("RGB", (width * scale, height * scale))
@ -46,6 +49,7 @@ def process_tile_spiral(
scale: int, scale: int,
filters: List[TileCallback], filters: List[TileCallback],
overlap: float = 0.5, overlap: float = 0.5,
**kwargs,
) -> Image.Image: ) -> Image.Image:
if scale != 1: if scale != 1:
raise Exception("unsupported scale") raise Exception("unsupported scale")
@ -87,3 +91,22 @@ def process_tile_spiral(
image.paste(tile_image, (left * scale, top * scale)) image.paste(tile_image, (left * scale, top * scale))
return image return image
def process_tile_order(
order: TileOrder,
source: Image.Image,
tile: int,
scale: int,
filters: List[TileCallback],
**kwargs,
) -> Image.Image:
if order == TileOrder.grid:
logger.debug("using grid tile order with tile size: %s", tile)
return process_tile_grid(source, tile, scale, filters, **kwargs)
elif order == TileOrder.kernel:
logger.debug("using kernel tile order with tile size: %s", tile)
raise NotImplementedError()
elif order == TileOrder.spiral:
logger.debug("using spiral tile order with tile size: %s", tile)
return process_tile_spiral(source, tile, scale, filters, **kwargs)

View File

@ -151,10 +151,11 @@ def run_inpaint_pipeline(
mask_filter: Any, mask_filter: Any,
strength: float, strength: float,
fill_color: str, fill_color: str,
tile_order: str,
) -> None: ) -> None:
# device = job.get_device() # device = job.get_device()
# progress = job.get_progress_callback() # progress = job.get_progress_callback()
stage = StageParams() stage = StageParams(tile_order=tile_order)
image = upscale_outpaint( image = upscale_outpaint(
job, job,

View File

@ -14,6 +14,12 @@ class SizeChart(IntEnum):
hd64k = 2**16 hd64k = 2**16
class TileOrder:
grid = "grid"
kernel = "kernel"
spiral = "spiral"
Param = Union[str, int, float] Param = Union[str, int, float]
Point = Tuple[int, int] Point = Tuple[int, int]
@ -122,13 +128,15 @@ class StageParams:
def __init__( def __init__(
self, self,
name: Optional[str] = None, name: Optional[str] = None,
tile_size: int = SizeChart.auto,
outscale: int = 1, outscale: int = 1,
tile_order: str = TileOrder.grid,
tile_size: int = SizeChart.auto,
# batch_size: int = 1, # batch_size: int = 1,
) -> None: ) -> None:
self.name = name self.name = name
self.tile_size = tile_size
self.outscale = outscale self.outscale = outscale
self.tile_order = tile_order
self.tile_size = tile_size
class UpscaleParams: class UpscaleParams:

View File

@ -64,7 +64,7 @@ from .image import ( # mask filters; noise sources
noise_source_uniform, noise_source_uniform,
) )
from .output import json_params, make_output_name from .output import json_params, make_output_name
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams, TileOrder
from .utils import ( from .utils import (
ServerContext, ServerContext,
base_join, base_join,
@ -589,6 +589,7 @@ def inpaint():
get_config_value("strength", "max"), get_config_value("strength", "max"),
get_config_value("strength", "min"), get_config_value("strength", "min"),
) )
tile_order = get_from_list(request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral])
output = make_output_name( output = make_output_name(
context, context,
@ -604,6 +605,7 @@ def inpaint():
noise_source.__name__, noise_source.__name__,
strength, strength,
fill_color, fill_color,
tile_order,
), ),
) )
logger.info("inpaint job queued for: %s", output) logger.info("inpaint job queued for: %s", output)
@ -625,6 +627,7 @@ def inpaint():
mask_filter, mask_filter,
strength, strength,
fill_color, fill_color,
tile_order,
needs_device=device, needs_device=device,
) )

View File

@ -66,10 +66,6 @@
"default": "histogram", "default": "histogram",
"keys": [] "keys": []
}, },
"order": {
"default": "spiral",
"keys": []
},
"outscale": { "outscale": {
"default": 1, "default": 1,
"min": 1, "min": 1,
@ -118,6 +114,10 @@
"max": 1, "max": 1,
"step": 0.01 "step": 0.01
}, },
"tileOrder": {
"default": "spiral",
"keys": []
},
"top": { "top": {
"default": 0, "default": 0,
"min": 0, "min": 0,

View File

@ -1,5 +1,5 @@
import { doesExist, mustExist } from '@apextoaster/js-utils'; import { doesExist, mustExist } from '@apextoaster/js-utils';
import { Box, Button, FormControl, FormControlLabel, InputLabel, Select, Stack } from '@mui/material'; import { Box, Button, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react'; import { useContext } from 'react';
import { useMutation, useQuery, useQueryClient } from 'react-query'; import { useMutation, useQuery, useQueryClient } from 'react-query';
@ -161,7 +161,7 @@ export function Inpaint() {
tileOrder: e.target.value, tileOrder: e.target.value,
}); });
}} }}
></Select> >{['grid', 'kernel', 'spiral'].map((name) => <MenuItem key={name} value={name}>{name}</MenuItem>)}</Select>
</FormControl> </FormControl>
<Stack direction='row' spacing={2}> <Stack direction='row' spacing={2}>
<FormControlLabel <FormControlLabel