feat(api): pass tile order to inpaint and outpaint pipelines
This commit is contained in:
parent
51651abd08
commit
3a290822eb
|
@ -9,7 +9,7 @@ from ..device_pool import JobContext
|
|||
from ..output import save_image
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_grid
|
||||
from .utils import process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -100,8 +100,12 @@ class ChainPipeline:
|
|||
|
||||
return tile
|
||||
|
||||
image = process_tile_grid(
|
||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile]
|
||||
image = process_tile_order(
|
||||
stage_params.tile_order,
|
||||
image,
|
||||
stage_params.tile_size,
|
||||
stage_params.outscale,
|
||||
[stage_tile],
|
||||
)
|
||||
else:
|
||||
logger.info("image within tile size, running stage")
|
||||
|
|
|
@ -12,7 +12,7 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram
|
|||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_grid
|
||||
from .utils import process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -101,7 +101,9 @@ def blend_inpaint(
|
|||
|
||||
return result.images[0]
|
||||
|
||||
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
||||
output = process_tile_order(
|
||||
stage.tile_order, source_image, SizeChart.auto, 1, [outpaint]
|
||||
)
|
||||
|
||||
logger.info("final output image size", output.size)
|
||||
return output
|
||||
|
|
|
@ -10,9 +10,9 @@ from ..device_pool import JobContext
|
|||
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||
from ..output import save_image
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams, TileOrder
|
||||
from ..utils import ServerContext, is_debug
|
||||
from .utils import process_tile_grid, process_tile_spiral
|
||||
from .utils import process_tile_grid, process_tile_order
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -120,8 +120,13 @@ def upscale_outpaint(
|
|||
"outpainting with an even border, using spiral tiling with %s overlap",
|
||||
overlap,
|
||||
)
|
||||
output = process_tile_spiral(
|
||||
source_image, SizeChart.auto, 1, [outpaint], overlap=overlap
|
||||
output = process_tile_order(
|
||||
stage.tile_order,
|
||||
source_image,
|
||||
SizeChart.auto,
|
||||
1,
|
||||
[outpaint],
|
||||
overlap=overlap,
|
||||
)
|
||||
else:
|
||||
logger.debug("outpainting with an uneven border, using grid tiling")
|
||||
|
|
|
@ -3,6 +3,8 @@ from typing import List, Protocol, Tuple
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import TileOrder
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -16,6 +18,7 @@ def process_tile_grid(
|
|||
tile: int,
|
||||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
width, height = source.size
|
||||
image = Image.new("RGB", (width * scale, height * scale))
|
||||
|
@ -46,6 +49,7 @@ def process_tile_spiral(
|
|||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
overlap: float = 0.5,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
if scale != 1:
|
||||
raise Exception("unsupported scale")
|
||||
|
@ -87,3 +91,22 @@ def process_tile_spiral(
|
|||
image.paste(tile_image, (left * scale, top * scale))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def process_tile_order(
|
||||
order: TileOrder,
|
||||
source: Image.Image,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
if order == TileOrder.grid:
|
||||
logger.debug("using grid tile order with tile size: %s", tile)
|
||||
return process_tile_grid(source, tile, scale, filters, **kwargs)
|
||||
elif order == TileOrder.kernel:
|
||||
logger.debug("using kernel tile order with tile size: %s", tile)
|
||||
raise NotImplementedError()
|
||||
elif order == TileOrder.spiral:
|
||||
logger.debug("using spiral tile order with tile size: %s", tile)
|
||||
return process_tile_spiral(source, tile, scale, filters, **kwargs)
|
||||
|
|
|
@ -151,10 +151,11 @@ def run_inpaint_pipeline(
|
|||
mask_filter: Any,
|
||||
strength: float,
|
||||
fill_color: str,
|
||||
tile_order: str,
|
||||
) -> None:
|
||||
# device = job.get_device()
|
||||
# progress = job.get_progress_callback()
|
||||
stage = StageParams()
|
||||
stage = StageParams(tile_order=tile_order)
|
||||
|
||||
image = upscale_outpaint(
|
||||
job,
|
||||
|
|
|
@ -14,6 +14,12 @@ class SizeChart(IntEnum):
|
|||
hd64k = 2**16
|
||||
|
||||
|
||||
class TileOrder:
|
||||
grid = "grid"
|
||||
kernel = "kernel"
|
||||
spiral = "spiral"
|
||||
|
||||
|
||||
Param = Union[str, int, float]
|
||||
Point = Tuple[int, int]
|
||||
|
||||
|
@ -122,13 +128,15 @@ class StageParams:
|
|||
def __init__(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
tile_size: int = SizeChart.auto,
|
||||
outscale: int = 1,
|
||||
tile_order: str = TileOrder.grid,
|
||||
tile_size: int = SizeChart.auto,
|
||||
# batch_size: int = 1,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.tile_size = tile_size
|
||||
self.outscale = outscale
|
||||
self.tile_order = tile_order
|
||||
self.tile_size = tile_size
|
||||
|
||||
|
||||
class UpscaleParams:
|
||||
|
|
|
@ -64,7 +64,7 @@ from .image import ( # mask filters; noise sources
|
|||
noise_source_uniform,
|
||||
)
|
||||
from .output import json_params, make_output_name
|
||||
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
||||
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams, TileOrder
|
||||
from .utils import (
|
||||
ServerContext,
|
||||
base_join,
|
||||
|
@ -589,6 +589,7 @@ def inpaint():
|
|||
get_config_value("strength", "max"),
|
||||
get_config_value("strength", "min"),
|
||||
)
|
||||
tile_order = get_from_list(request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral])
|
||||
|
||||
output = make_output_name(
|
||||
context,
|
||||
|
@ -604,6 +605,7 @@ def inpaint():
|
|||
noise_source.__name__,
|
||||
strength,
|
||||
fill_color,
|
||||
tile_order,
|
||||
),
|
||||
)
|
||||
logger.info("inpaint job queued for: %s", output)
|
||||
|
@ -625,6 +627,7 @@ def inpaint():
|
|||
mask_filter,
|
||||
strength,
|
||||
fill_color,
|
||||
tile_order,
|
||||
needs_device=device,
|
||||
)
|
||||
|
||||
|
|
|
@ -66,10 +66,6 @@
|
|||
"default": "histogram",
|
||||
"keys": []
|
||||
},
|
||||
"order": {
|
||||
"default": "spiral",
|
||||
"keys": []
|
||||
},
|
||||
"outscale": {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
|
@ -118,6 +114,10 @@
|
|||
"max": 1,
|
||||
"step": 0.01
|
||||
},
|
||||
"tileOrder": {
|
||||
"default": "spiral",
|
||||
"keys": []
|
||||
},
|
||||
"top": {
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { Box, Button, FormControl, FormControlLabel, InputLabel, Select, Stack } from '@mui/material';
|
||||
import { Box, Button, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useMutation, useQuery, useQueryClient } from 'react-query';
|
||||
|
@ -161,7 +161,7 @@ export function Inpaint() {
|
|||
tileOrder: e.target.value,
|
||||
});
|
||||
}}
|
||||
></Select>
|
||||
>{['grid', 'kernel', 'spiral'].map((name) => <MenuItem key={name} value={name}>{name}</MenuItem>)}</Select>
|
||||
</FormControl>
|
||||
<Stack direction='row' spacing={2}>
|
||||
<FormControlLabel
|
||||
|
|
Loading…
Reference in New Issue