feat(api): pass tile order to inpaint and outpaint pipelines
This commit is contained in:
parent
51651abd08
commit
3a290822eb
|
@ -9,7 +9,7 @@ from ..device_pool import JobContext
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import ImageParams, StageParams
|
from ..params import ImageParams, StageParams
|
||||||
from ..utils import ServerContext, is_debug
|
from ..utils import ServerContext, is_debug
|
||||||
from .utils import process_tile_grid
|
from .utils import process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -100,8 +100,12 @@ class ChainPipeline:
|
||||||
|
|
||||||
return tile
|
return tile
|
||||||
|
|
||||||
image = process_tile_grid(
|
image = process_tile_order(
|
||||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile]
|
stage_params.tile_order,
|
||||||
|
image,
|
||||||
|
stage_params.tile_size,
|
||||||
|
stage_params.outscale,
|
||||||
|
[stage_tile],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("image within tile size, running stage")
|
logger.info("image within tile size, running stage")
|
||||||
|
|
|
@ -12,7 +12,7 @@ from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
||||||
from ..utils import ServerContext, is_debug
|
from ..utils import ServerContext, is_debug
|
||||||
from .utils import process_tile_grid
|
from .utils import process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -101,7 +101,9 @@ def blend_inpaint(
|
||||||
|
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
|
|
||||||
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
output = process_tile_order(
|
||||||
|
stage.tile_order, source_image, SizeChart.auto, 1, [outpaint]
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("final output image size", output.size)
|
logger.info("final output image size", output.size)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -10,9 +10,9 @@ from ..device_pool import JobContext
|
||||||
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
from ..diffusion.load import get_latents_from_seed, get_tile_latents, load_pipeline
|
||||||
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
from ..image import expand_image, mask_filter_none, noise_source_histogram
|
||||||
from ..output import save_image
|
from ..output import save_image
|
||||||
from ..params import Border, ImageParams, Size, SizeChart, StageParams
|
from ..params import Border, ImageParams, Size, SizeChart, StageParams, TileOrder
|
||||||
from ..utils import ServerContext, is_debug
|
from ..utils import ServerContext, is_debug
|
||||||
from .utils import process_tile_grid, process_tile_spiral
|
from .utils import process_tile_grid, process_tile_order
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -120,8 +120,13 @@ def upscale_outpaint(
|
||||||
"outpainting with an even border, using spiral tiling with %s overlap",
|
"outpainting with an even border, using spiral tiling with %s overlap",
|
||||||
overlap,
|
overlap,
|
||||||
)
|
)
|
||||||
output = process_tile_spiral(
|
output = process_tile_order(
|
||||||
source_image, SizeChart.auto, 1, [outpaint], overlap=overlap
|
stage.tile_order,
|
||||||
|
source_image,
|
||||||
|
SizeChart.auto,
|
||||||
|
1,
|
||||||
|
[outpaint],
|
||||||
|
overlap=overlap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("outpainting with an uneven border, using grid tiling")
|
logger.debug("outpainting with an uneven border, using grid tiling")
|
||||||
|
|
|
@ -3,6 +3,8 @@ from typing import List, Protocol, Tuple
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from ..params import TileOrder
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,6 +18,7 @@ def process_tile_grid(
|
||||||
tile: int,
|
tile: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
filters: List[TileCallback],
|
filters: List[TileCallback],
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
width, height = source.size
|
width, height = source.size
|
||||||
image = Image.new("RGB", (width * scale, height * scale))
|
image = Image.new("RGB", (width * scale, height * scale))
|
||||||
|
@ -46,6 +49,7 @@ def process_tile_spiral(
|
||||||
scale: int,
|
scale: int,
|
||||||
filters: List[TileCallback],
|
filters: List[TileCallback],
|
||||||
overlap: float = 0.5,
|
overlap: float = 0.5,
|
||||||
|
**kwargs,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
if scale != 1:
|
if scale != 1:
|
||||||
raise Exception("unsupported scale")
|
raise Exception("unsupported scale")
|
||||||
|
@ -87,3 +91,22 @@ def process_tile_spiral(
|
||||||
image.paste(tile_image, (left * scale, top * scale))
|
image.paste(tile_image, (left * scale, top * scale))
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def process_tile_order(
|
||||||
|
order: TileOrder,
|
||||||
|
source: Image.Image,
|
||||||
|
tile: int,
|
||||||
|
scale: int,
|
||||||
|
filters: List[TileCallback],
|
||||||
|
**kwargs,
|
||||||
|
) -> Image.Image:
|
||||||
|
if order == TileOrder.grid:
|
||||||
|
logger.debug("using grid tile order with tile size: %s", tile)
|
||||||
|
return process_tile_grid(source, tile, scale, filters, **kwargs)
|
||||||
|
elif order == TileOrder.kernel:
|
||||||
|
logger.debug("using kernel tile order with tile size: %s", tile)
|
||||||
|
raise NotImplementedError()
|
||||||
|
elif order == TileOrder.spiral:
|
||||||
|
logger.debug("using spiral tile order with tile size: %s", tile)
|
||||||
|
return process_tile_spiral(source, tile, scale, filters, **kwargs)
|
||||||
|
|
|
@ -151,10 +151,11 @@ def run_inpaint_pipeline(
|
||||||
mask_filter: Any,
|
mask_filter: Any,
|
||||||
strength: float,
|
strength: float,
|
||||||
fill_color: str,
|
fill_color: str,
|
||||||
|
tile_order: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
# device = job.get_device()
|
# device = job.get_device()
|
||||||
# progress = job.get_progress_callback()
|
# progress = job.get_progress_callback()
|
||||||
stage = StageParams()
|
stage = StageParams(tile_order=tile_order)
|
||||||
|
|
||||||
image = upscale_outpaint(
|
image = upscale_outpaint(
|
||||||
job,
|
job,
|
||||||
|
|
|
@ -14,6 +14,12 @@ class SizeChart(IntEnum):
|
||||||
hd64k = 2**16
|
hd64k = 2**16
|
||||||
|
|
||||||
|
|
||||||
|
class TileOrder:
|
||||||
|
grid = "grid"
|
||||||
|
kernel = "kernel"
|
||||||
|
spiral = "spiral"
|
||||||
|
|
||||||
|
|
||||||
Param = Union[str, int, float]
|
Param = Union[str, int, float]
|
||||||
Point = Tuple[int, int]
|
Point = Tuple[int, int]
|
||||||
|
|
||||||
|
@ -122,13 +128,15 @@ class StageParams:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
tile_size: int = SizeChart.auto,
|
|
||||||
outscale: int = 1,
|
outscale: int = 1,
|
||||||
|
tile_order: str = TileOrder.grid,
|
||||||
|
tile_size: int = SizeChart.auto,
|
||||||
# batch_size: int = 1,
|
# batch_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.tile_size = tile_size
|
|
||||||
self.outscale = outscale
|
self.outscale = outscale
|
||||||
|
self.tile_order = tile_order
|
||||||
|
self.tile_size = tile_size
|
||||||
|
|
||||||
|
|
||||||
class UpscaleParams:
|
class UpscaleParams:
|
||||||
|
|
|
@ -64,7 +64,7 @@ from .image import ( # mask filters; noise sources
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
from .output import json_params, make_output_name
|
from .output import json_params, make_output_name
|
||||||
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams
|
from .params import Border, DeviceParams, ImageParams, Size, StageParams, UpscaleParams, TileOrder
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ServerContext,
|
ServerContext,
|
||||||
base_join,
|
base_join,
|
||||||
|
@ -589,6 +589,7 @@ def inpaint():
|
||||||
get_config_value("strength", "max"),
|
get_config_value("strength", "max"),
|
||||||
get_config_value("strength", "min"),
|
get_config_value("strength", "min"),
|
||||||
)
|
)
|
||||||
|
tile_order = get_from_list(request.args, "tileOrder", [TileOrder.grid, TileOrder.kernel, TileOrder.spiral])
|
||||||
|
|
||||||
output = make_output_name(
|
output = make_output_name(
|
||||||
context,
|
context,
|
||||||
|
@ -604,6 +605,7 @@ def inpaint():
|
||||||
noise_source.__name__,
|
noise_source.__name__,
|
||||||
strength,
|
strength,
|
||||||
fill_color,
|
fill_color,
|
||||||
|
tile_order,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
logger.info("inpaint job queued for: %s", output)
|
logger.info("inpaint job queued for: %s", output)
|
||||||
|
@ -625,6 +627,7 @@ def inpaint():
|
||||||
mask_filter,
|
mask_filter,
|
||||||
strength,
|
strength,
|
||||||
fill_color,
|
fill_color,
|
||||||
|
tile_order,
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -66,10 +66,6 @@
|
||||||
"default": "histogram",
|
"default": "histogram",
|
||||||
"keys": []
|
"keys": []
|
||||||
},
|
},
|
||||||
"order": {
|
|
||||||
"default": "spiral",
|
|
||||||
"keys": []
|
|
||||||
},
|
|
||||||
"outscale": {
|
"outscale": {
|
||||||
"default": 1,
|
"default": 1,
|
||||||
"min": 1,
|
"min": 1,
|
||||||
|
@ -118,6 +114,10 @@
|
||||||
"max": 1,
|
"max": 1,
|
||||||
"step": 0.01
|
"step": 0.01
|
||||||
},
|
},
|
||||||
|
"tileOrder": {
|
||||||
|
"default": "spiral",
|
||||||
|
"keys": []
|
||||||
|
},
|
||||||
"top": {
|
"top": {
|
||||||
"default": 0,
|
"default": 0,
|
||||||
"min": 0,
|
"min": 0,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||||
import { Box, Button, FormControl, FormControlLabel, InputLabel, Select, Stack } from '@mui/material';
|
import { Box, Button, FormControl, FormControlLabel, InputLabel, MenuItem, Select, Stack } from '@mui/material';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useContext } from 'react';
|
import { useContext } from 'react';
|
||||||
import { useMutation, useQuery, useQueryClient } from 'react-query';
|
import { useMutation, useQuery, useQueryClient } from 'react-query';
|
||||||
|
@ -161,7 +161,7 @@ export function Inpaint() {
|
||||||
tileOrder: e.target.value,
|
tileOrder: e.target.value,
|
||||||
});
|
});
|
||||||
}}
|
}}
|
||||||
></Select>
|
>{['grid', 'kernel', 'spiral'].map((name) => <MenuItem key={name} value={name}>{name}</MenuItem>)}</Select>
|
||||||
</FormControl>
|
</FormControl>
|
||||||
<Stack direction='row' spacing={2}>
|
<Stack direction='row' spacing={2}>
|
||||||
<FormControlLabel
|
<FormControlLabel
|
||||||
|
|
Loading…
Reference in New Issue