1
0
Fork 0

feat(api): implement spiral grid for outpainting

This commit is contained in:
Sean Sube 2023-01-28 23:46:36 -06:00
parent 680adc70ea
commit a4d3f18a48
6 changed files with 60 additions and 15 deletions

View File

@ -14,7 +14,7 @@ from ..utils import (
ServerContext, ServerContext,
) )
from .utils import ( from .utils import (
process_tiles, process_tile_grid,
) )
logger = getLogger(__name__) logger = getLogger(__name__)
@ -86,7 +86,7 @@ class ChainPipeline:
return tile return tile
image = process_tiles( image = process_tile_grid(
image, stage_params.tile_size, stage_params.outscale, [stage_tile]) image, stage_params.tile_size, stage_params.outscale, [stage_tile])
else: else:
logger.info('source image within tile size, running stage') logger.info('source image within tile size, running stage')

View File

@ -27,7 +27,7 @@ from ..utils import (
ServerContext, ServerContext,
) )
from .utils import ( from .utils import (
process_tiles, process_tile_grid,
) )
import numpy as np import numpy as np
@ -98,7 +98,7 @@ def blend_inpaint(
) )
return result.images[0] return result.images[0]
output = process_tiles(source_image, SizeChart.auto, 1, [outpaint]) output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
logger.info('final output image size', output.size) logger.info('final output image size', output.size)
return output return output

View File

@ -27,7 +27,7 @@ from ..utils import (
ServerContext, ServerContext,
) )
from .utils import ( from .utils import (
process_tiles, process_tile_spiral,
) )
import numpy as np import numpy as np
@ -98,7 +98,7 @@ def upscale_outpaint(
) )
return result.images[0] return result.images[0]
output = process_tiles(source_image, SizeChart.auto, 1, [outpaint]) output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
logger.info('final output image size: %sx%s', output.width, output.height) logger.info('final output image size: %sx%s', output.width, output.height)
return output return output

View File

@ -10,7 +10,7 @@ class TileCallback(Protocol):
pass pass
def process_tiles( def process_tile_grid(
source: Image.Image, source: Image.Image,
tile: int, tile: int,
scale: int, scale: int,
@ -37,3 +37,47 @@ def process_tiles(
image.paste(tile_image, (left * scale, top * scale)) image.paste(tile_image, (left * scale, top * scale))
return image return image
def process_tile_spiral(
source: Image.Image,
tile: int,
scale: int,
filters: List[TileCallback],
overlap: float = 0.5,
) -> Image.Image:
if scale != 1:
raise Exception('unsupported scale')
width, height = source.size
image = Image.new('RGB', (width * scale, height * scale))
image.paste(source, (0, 0))
# TODO: only valid for overlap = 0.5
if overlap == 0.5:
tiles = [
(0, tile * -overlap),
(tile * overlap, tile * -overlap),
(tile * overlap, 0),
(tile * overlap, tile * overlap),
(0, tile * overlap),
(tile * -overlap, tile * -overlap),
(tile * -overlap, 0),
(tile * -overlap, tile * overlap),
]
# tile tuples is source, multiply by scale for dest
counter = 0
for left, top in tiles:
logger.info('processing tile %s of %s', counter, len(tiles))
counter += 1
# TODO: only valid for scale == 1, resize source for others
tile_image = image.crop((left, top, left + tile, top + tile))
for filter in filters:
tile_image = filter(tile_image, (left, top, tile))
image.paste(tile_image, (left * scale, top * scale))
return image

View File

@ -97,15 +97,16 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
if val is None: if val is None:
return SizeChart.auto return SizeChart.auto
if type(val) is str:
if val in SizeChart:
return SizeChart[val]
else:
return int(val)
if type(val) is int: if type(val) is int:
return val return val
if type(val) is str:
for size in SizeChart:
if val == size.name:
return size
return int(val)
raise Exception('invalid size') raise Exception('invalid size')

View File

@ -30,7 +30,7 @@
"name": "save-local", "name": "save-local",
"type": "persist-disk", "type": "persist-disk",
"params": { "params": {
"tile_size": "8k" "tile_size": "hd8k"
} }
}, },
{ {
@ -40,7 +40,7 @@
"bucket": "storage-stable-diffusion", "bucket": "storage-stable-diffusion",
"endpoint_url": "http://scylla.home.holdmyran.ch:8000", "endpoint_url": "http://scylla.home.holdmyran.ch:8000",
"profile_name": "ceph", "profile_name": "ceph",
"tile_size": "8k" "tile_size": "hd8k"
} }
} }
] ]