feat(api): implement spiral grid for outpainting
This commit is contained in:
parent
680adc70ea
commit
a4d3f18a48
|
@ -14,7 +14,7 @@ from ..utils import (
|
|||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tiles,
|
||||
process_tile_grid,
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -86,7 +86,7 @@ class ChainPipeline:
|
|||
|
||||
return tile
|
||||
|
||||
image = process_tiles(
|
||||
image = process_tile_grid(
|
||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||
else:
|
||||
logger.info('source image within tile size, running stage')
|
||||
|
|
|
@ -27,7 +27,7 @@ from ..utils import (
|
|||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tiles,
|
||||
process_tile_grid,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
@ -98,7 +98,7 @@ def blend_inpaint(
|
|||
)
|
||||
return result.images[0]
|
||||
|
||||
output = process_tiles(source_image, SizeChart.auto, 1, [outpaint])
|
||||
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
logger.info('final output image size', output.size)
|
||||
return output
|
||||
|
|
|
@ -27,7 +27,7 @@ from ..utils import (
|
|||
ServerContext,
|
||||
)
|
||||
from .utils import (
|
||||
process_tiles,
|
||||
process_tile_spiral,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
@ -98,7 +98,7 @@ def upscale_outpaint(
|
|||
)
|
||||
return result.images[0]
|
||||
|
||||
output = process_tiles(source_image, SizeChart.auto, 1, [outpaint])
|
||||
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
|
||||
|
||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||
return output
|
||||
|
|
|
@ -10,7 +10,7 @@ class TileCallback(Protocol):
|
|||
pass
|
||||
|
||||
|
||||
def process_tiles(
|
||||
def process_tile_grid(
|
||||
source: Image.Image,
|
||||
tile: int,
|
||||
scale: int,
|
||||
|
@ -37,3 +37,47 @@ def process_tiles(
|
|||
image.paste(tile_image, (left * scale, top * scale))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def process_tile_spiral(
|
||||
source: Image.Image,
|
||||
tile: int,
|
||||
scale: int,
|
||||
filters: List[TileCallback],
|
||||
overlap: float = 0.5,
|
||||
) -> Image.Image:
|
||||
if scale != 1:
|
||||
raise Exception('unsupported scale')
|
||||
|
||||
width, height = source.size
|
||||
image = Image.new('RGB', (width * scale, height * scale))
|
||||
image.paste(source, (0, 0))
|
||||
|
||||
# TODO: only valid for overlap = 0.5
|
||||
if overlap == 0.5:
|
||||
tiles = [
|
||||
(0, tile * -overlap),
|
||||
(tile * overlap, tile * -overlap),
|
||||
(tile * overlap, 0),
|
||||
(tile * overlap, tile * overlap),
|
||||
(0, tile * overlap),
|
||||
(tile * -overlap, tile * -overlap),
|
||||
(tile * -overlap, 0),
|
||||
(tile * -overlap, tile * overlap),
|
||||
]
|
||||
|
||||
# tile tuples is source, multiply by scale for dest
|
||||
counter = 0
|
||||
for left, top in tiles:
|
||||
logger.info('processing tile %s of %s', counter, len(tiles))
|
||||
counter += 1
|
||||
|
||||
# TODO: only valid for scale == 1, resize source for others
|
||||
tile_image = image.crop((left, top, left + tile, top + tile))
|
||||
|
||||
for filter in filters:
|
||||
tile_image = filter(tile_image, (left, top, tile))
|
||||
|
||||
image.paste(tile_image, (left * scale, top * scale))
|
||||
|
||||
return image
|
||||
|
|
|
@ -97,15 +97,16 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
|||
if val is None:
|
||||
return SizeChart.auto
|
||||
|
||||
if type(val) is str:
|
||||
if val in SizeChart:
|
||||
return SizeChart[val]
|
||||
else:
|
||||
return int(val)
|
||||
|
||||
if type(val) is int:
|
||||
return val
|
||||
|
||||
if type(val) is str:
|
||||
for size in SizeChart:
|
||||
if val == size.name:
|
||||
return size
|
||||
|
||||
return int(val)
|
||||
|
||||
raise Exception('invalid size')
|
||||
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
"name": "save-local",
|
||||
"type": "persist-disk",
|
||||
"params": {
|
||||
"tile_size": "8k"
|
||||
"tile_size": "hd8k"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -40,7 +40,7 @@
|
|||
"bucket": "storage-stable-diffusion",
|
||||
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
|
||||
"profile_name": "ceph",
|
||||
"tile_size": "8k"
|
||||
"tile_size": "hd8k"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue