feat(api): implement spiral grid for outpainting
This commit is contained in:
parent
680adc70ea
commit
a4d3f18a48
|
@ -14,7 +14,7 @@ from ..utils import (
|
||||||
ServerContext,
|
ServerContext,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
process_tiles,
|
process_tile_grid,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -86,7 +86,7 @@ class ChainPipeline:
|
||||||
|
|
||||||
return tile
|
return tile
|
||||||
|
|
||||||
image = process_tiles(
|
image = process_tile_grid(
|
||||||
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
image, stage_params.tile_size, stage_params.outscale, [stage_tile])
|
||||||
else:
|
else:
|
||||||
logger.info('source image within tile size, running stage')
|
logger.info('source image within tile size, running stage')
|
||||||
|
|
|
@ -27,7 +27,7 @@ from ..utils import (
|
||||||
ServerContext,
|
ServerContext,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
process_tiles,
|
process_tile_grid,
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -98,7 +98,7 @@ def blend_inpaint(
|
||||||
)
|
)
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
|
|
||||||
output = process_tiles(source_image, SizeChart.auto, 1, [outpaint])
|
output = process_tile_grid(source_image, SizeChart.auto, 1, [outpaint])
|
||||||
|
|
||||||
logger.info('final output image size', output.size)
|
logger.info('final output image size', output.size)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -27,7 +27,7 @@ from ..utils import (
|
||||||
ServerContext,
|
ServerContext,
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
process_tiles,
|
process_tile_spiral,
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -98,7 +98,7 @@ def upscale_outpaint(
|
||||||
)
|
)
|
||||||
return result.images[0]
|
return result.images[0]
|
||||||
|
|
||||||
output = process_tiles(source_image, SizeChart.auto, 1, [outpaint])
|
output = process_tile_spiral(source_image, SizeChart.auto, 1, [outpaint])
|
||||||
|
|
||||||
logger.info('final output image size: %sx%s', output.width, output.height)
|
logger.info('final output image size: %sx%s', output.width, output.height)
|
||||||
return output
|
return output
|
||||||
|
|
|
@ -10,7 +10,7 @@ class TileCallback(Protocol):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def process_tiles(
|
def process_tile_grid(
|
||||||
source: Image.Image,
|
source: Image.Image,
|
||||||
tile: int,
|
tile: int,
|
||||||
scale: int,
|
scale: int,
|
||||||
|
@ -37,3 +37,47 @@ def process_tiles(
|
||||||
image.paste(tile_image, (left * scale, top * scale))
|
image.paste(tile_image, (left * scale, top * scale))
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def process_tile_spiral(
|
||||||
|
source: Image.Image,
|
||||||
|
tile: int,
|
||||||
|
scale: int,
|
||||||
|
filters: List[TileCallback],
|
||||||
|
overlap: float = 0.5,
|
||||||
|
) -> Image.Image:
|
||||||
|
if scale != 1:
|
||||||
|
raise Exception('unsupported scale')
|
||||||
|
|
||||||
|
width, height = source.size
|
||||||
|
image = Image.new('RGB', (width * scale, height * scale))
|
||||||
|
image.paste(source, (0, 0))
|
||||||
|
|
||||||
|
# TODO: only valid for overlap = 0.5
|
||||||
|
if overlap == 0.5:
|
||||||
|
tiles = [
|
||||||
|
(0, tile * -overlap),
|
||||||
|
(tile * overlap, tile * -overlap),
|
||||||
|
(tile * overlap, 0),
|
||||||
|
(tile * overlap, tile * overlap),
|
||||||
|
(0, tile * overlap),
|
||||||
|
(tile * -overlap, tile * -overlap),
|
||||||
|
(tile * -overlap, 0),
|
||||||
|
(tile * -overlap, tile * overlap),
|
||||||
|
]
|
||||||
|
|
||||||
|
# tile tuples is source, multiply by scale for dest
|
||||||
|
counter = 0
|
||||||
|
for left, top in tiles:
|
||||||
|
logger.info('processing tile %s of %s', counter, len(tiles))
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
# TODO: only valid for scale == 1, resize source for others
|
||||||
|
tile_image = image.crop((left, top, left + tile, top + tile))
|
||||||
|
|
||||||
|
for filter in filters:
|
||||||
|
tile_image = filter(tile_image, (left, top, tile))
|
||||||
|
|
||||||
|
image.paste(tile_image, (left * scale, top * scale))
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
|
@ -97,15 +97,16 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
|
||||||
if val is None:
|
if val is None:
|
||||||
return SizeChart.auto
|
return SizeChart.auto
|
||||||
|
|
||||||
if type(val) is str:
|
|
||||||
if val in SizeChart:
|
|
||||||
return SizeChart[val]
|
|
||||||
else:
|
|
||||||
return int(val)
|
|
||||||
|
|
||||||
if type(val) is int:
|
if type(val) is int:
|
||||||
return val
|
return val
|
||||||
|
|
||||||
|
if type(val) is str:
|
||||||
|
for size in SizeChart:
|
||||||
|
if val == size.name:
|
||||||
|
return size
|
||||||
|
|
||||||
|
return int(val)
|
||||||
|
|
||||||
raise Exception('invalid size')
|
raise Exception('invalid size')
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@
|
||||||
"name": "save-local",
|
"name": "save-local",
|
||||||
"type": "persist-disk",
|
"type": "persist-disk",
|
||||||
"params": {
|
"params": {
|
||||||
"tile_size": "8k"
|
"tile_size": "hd8k"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -40,7 +40,7 @@
|
||||||
"bucket": "storage-stable-diffusion",
|
"bucket": "storage-stable-diffusion",
|
||||||
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
|
"endpoint_url": "http://scylla.home.holdmyran.ch:8000",
|
||||||
"profile_name": "ceph",
|
"profile_name": "ceph",
|
||||||
"tile_size": "8k"
|
"tile_size": "hd8k"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue