1
0
Fork 0

rewrite tile handling for image stacks

This commit is contained in:
Sean Sube 2023-11-19 18:39:39 -06:00
parent eb77c83d80
commit 98fcc07524
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
13 changed files with 254 additions and 201 deletions

View File

@ -29,5 +29,8 @@ class BlendLinearStage(BaseStage):
logger.info("blending source images using linear interpolation")
return StageResult(
images=[Image.blend(source, stage_source, alpha) for source in sources.as_image()]
images=[
Image.blend(source, stage_source, alpha)
for source in sources.as_image()
]
)

View File

@ -40,6 +40,7 @@ class BlendMaskStage(BaseStage):
return StageResult(
images=[
Image.composite(stage_source, source, mult_mask) for source in sources.as_image()
Image.composite(stage_source, source, mult_mask)
for source in sources.as_image()
]
)

View File

@ -29,9 +29,7 @@ class PersistDiskStage(BaseStage):
stage_source: Optional[Image.Image] = None,
**kwargs,
) -> StageResult:
logger.info(
"persisting %s images to disk: %s", len(sources), output
)
logger.info("persisting %s images to disk: %s", len(sources), output)
for source, name in zip(sources, output):
dest = save_image(server, name, source, params=params, size=size)

View File

@ -12,8 +12,8 @@ from ..server import ServerContext
from ..utils import is_debug, run_gc
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .tile import needs_tile, process_tile_order
from .result import StageResult
from .tile import needs_tile, process_tile_order
logger = getLogger(__name__)
@ -163,19 +163,17 @@ class ChainPipeline:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
# TODO: stage_sources will always be defined here
if stage_sources or must_tile:
stage_results = []
for source in stage_sources.as_image():
if must_tile:
logger.info(
"image contains sources or is larger than tile size of %s, tiling stage",
tile,
)
def stage_tile(
source_tile: Image.Image,
source_tile: List[Image.Image],
tile_mask: Image.Image,
dims: Tuple[int, int, int],
) -> StageResult:
) -> List[Image.Image]:
for _i in range(worker.retries):
try:
tile_result = stage_pipe.run(
@ -194,8 +192,7 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image)
# TODO: return whole result
return tile_result.as_image()[0]
return tile_result.as_image()
except Exception:
worker.retries = worker.retries - 1
logger.exception(
@ -207,17 +204,15 @@ class ChainPipeline:
raise RetryException("exhausted retries on tile")
output = process_tile_order(
stage_results = process_tile_order(
stage_params.tile_order,
source,
stage_sources,
tile,
stage_params.outscale,
[stage_tile],
**kwargs,
)
stage_results.append(output)
stage_sources = StageResult(images=stage_results)
else:
logger.debug(

View File

@ -2,7 +2,7 @@ import itertools
from enum import Enum
from logging import getLogger
from math import ceil
from typing import Any, List, Optional, Protocol, Tuple
from typing import Any, Callable, List, Optional, Protocol, Tuple
import numpy as np
from PIL import Image
@ -17,12 +17,17 @@ from .result import StageResult
logger = getLogger(__name__)
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
class TileCallback(Protocol):
"""
Definition for a tile job function.
"""
def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> StageResult:
def __call__(
self, image: Image.Image, dims: Tuple[int, int, int]
) -> List[Image.Image]:
"""
Run this stage against a single tile.
"""
@ -33,6 +38,9 @@ def complete_tile(
source: Image.Image,
tile: int,
) -> Image.Image:
"""
TODO: clean up
"""
if source is None:
return source
@ -67,7 +75,7 @@ def needs_tile(
return False
def get_tile_grads(
def make_tile_grads(
left: int,
top: int,
tile: int,
@ -161,7 +169,7 @@ def blend_tiles(
points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
# gradient blending
grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
@ -225,60 +233,18 @@ def blend_tiles(
return Image.fromarray(np.uint8(pixels))
def process_tile_grid(
source: Image.Image,
tile: int,
scale: int,
filters: List[TileCallback],
overlap: float = 0.0,
**kwargs,
) -> Image.Image:
width, height = kwargs.get("size", source.size if source else None)
adj_tile = int(float(tile) * (1.0 - overlap))
tiles_x = ceil(width / adj_tile)
tiles_y = ceil(height / adj_tile)
total = tiles_x * tiles_y
logger.debug(
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
total,
tiles_x,
tiles_y,
adj_tile,
overlap,
)
tiles: List[Tuple[int, int, Image.Image]] = []
for y in range(tiles_y):
for x in range(tiles_x):
idx = (y * tiles_x) + x
left = x * adj_tile
top = y * adj_tile
logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)
tile_image = (
source.crop((left, top, left + tile, top + tile)) if source else None
)
tile_image = complete_tile(tile_image, tile)
for filter in filters:
tile_image = filter(tile_image, (left, top, tile))
tiles.append((left, top, tile_image))
return blend_tiles(tiles, scale, width, height, tile, overlap)
def process_tile_spiral(
source: Image.Image,
def process_tile_stack(
stack: StageResult,
tile: int,
scale: int,
filters: List[TileCallback],
tile_generator: TileGenerator,
overlap: float = 0.5,
**kwargs,
) -> Image.Image:
width, height = kwargs.get("size", source.size if source else None)
) -> List[Image.Image]:
sources = stack.as_image()
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
mask = kwargs.get("mask", None)
noise_source = kwargs.get("noise_source", noise_source_histogram)
fill_color = kwargs.get("fill_color", None)
@ -286,18 +252,9 @@ def process_tile_spiral(
tile_mask = None
tiles: List[Tuple[int, int, Image.Image]] = []
tile_coords = tile_generator(width, height, tile, overlap)
# tile tuples is source, multiply by scale for dest
counter = 0
tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
if len(tile_coords) == 1:
single_tile = True
else:
single_tile = False
for left, top in tile_coords:
counter += 1
for counter, (left, top) in enumerate(tile_coords):
logger.info(
"processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
)
@ -321,37 +278,28 @@ def process_tile_spiral(
needs_margin = True
bottom_margin = height - bottom
# if no source given, we don't have a source image
if not source:
tile_image = None
elif needs_margin:
# in the special case where the image is smaller than the specified tile size, just use the image
if single_tile:
logger.debug("creating and processing single-tile subtile")
tile_image = source
if mask:
tile_mask = mask
# otherwise use add histogram noise outside of the image border
else:
if needs_margin:
logger.debug(
"tiling and adding margins: %s, %s, %s, %s",
"tiling with added margins: %s, %s, %s, %s",
left_margin,
top_margin,
right_margin,
bottom_margin,
)
base_image = source.crop(
(
left + left_margin,
top + top_margin,
right + right_margin,
bottom + bottom_margin,
tile_stack = add_margin(
stack,
left,
top,
right,
bottom,
left_margin,
top_margin,
right_margin,
bottom_margin,
tile,
noise_source,
fill_color,
)
)
tile_image = noise_source(
base_image, (tile, tile), (0, 0), fill=fill_color
)
tile_image.paste(base_image, (left_margin, top_margin))
if mask:
base_mask = mask.crop(
@ -367,24 +315,30 @@ def process_tile_spiral(
else:
logger.debug("tiling normally")
tile_image = source.crop((left, top, right, bottom))
tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
if mask:
tile_mask = mask.crop((left, top, right, bottom))
for image_filter in filters:
tile_image = image_filter(tile_image, tile_mask, (left, top, tile))
tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))
tiles.append((left, top, tile_image))
tiles.append((left, top, tile_stack))
if single_tile:
return tile_image
else:
return blend_tiles(tiles, scale, width, height, tile, overlap)
lefts, tops, stacks = list(zip(*tiles))
coords = list(zip(lefts, tops))
stacks = list(zip(*stacks))
result = []
for stack in stacks:
stack_tiles = zip(coords, stack)
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
return result
def process_tile_order(
order: TileOrder,
source: Image.Image,
stack: StageResult,
tile: int,
scale: int,
filters: List[TileCallback],
@ -395,13 +349,17 @@ def process_tile_order(
"""
if order == TileOrder.grid:
logger.debug("using grid tile order with tile size: %s", tile)
return process_tile_grid(source, tile, scale, filters, **kwargs)
return process_tile_stack(
stack, tile, scale, filters, generate_tile_grid, **kwargs
)
elif order == TileOrder.kernel:
logger.debug("using kernel tile order with tile size: %s", tile)
raise NotImplementedError()
elif order == TileOrder.spiral:
logger.debug("using spiral tile order with tile size: %s", tile)
return process_tile_spiral(source, tile, scale, filters, **kwargs)
return process_tile_stack(
stack, tile, scale, filters, generate_tile_spiral, **kwargs
)
else:
logger.warning("unknown tile order: %s", order)
raise ValueError()
@ -495,3 +453,76 @@ def generate_tile_spiral(
height_tile_target -= abs(state.value[1])
return tile_coords
def generate_tile_grid(
    width: int,
    height: int,
    tile: int,
    overlap: float = 0.0,
) -> List[Tuple[int, int]]:
    """
    Generate the (left, top) origin of every tile in a regular grid.

    Tiles are emitted row by row: all columns of the first row, then the
    second row, and so on. Neighbouring origins are `tile * (1 - overlap)`
    pixels apart, truncated to an integer.

    Args:
        width: width of the area to cover, in pixels.
        height: height of the area to cover, in pixels.
        tile: edge length of a single tile, in pixels.
        overlap: fraction of each tile shared with its neighbour.

    Returns:
        A list of (left, top) tuples, one per tile.

    Raises:
        ValueError: if tile and overlap leave a stride of less than one pixel,
            which would otherwise divide by zero or yield an empty grid.
    """
    adj_tile = int(float(tile) * (1.0 - overlap))
    if adj_tile < 1:
        raise ValueError(
            f"tile size {tile} with overlap {overlap} leaves no room to advance"
        )

    tiles_x = ceil(width / adj_tile)
    tiles_y = ceil(height / adj_tile)
    total = tiles_x * tiles_y
    logger.debug(
        "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
        total,
        tiles_x,
        tiles_y,
        adj_tile,
        overlap,
    )

    # y is the outer loop so that the result is in row-major order
    tiles: List[Tuple[int, int]] = [
        (x * adj_tile, y * adj_tile)
        for y in range(tiles_y)
        for x in range(tiles_x)
    ]

    return tiles
def get_result_tile(
    result: StageResult,
    origin: Tuple[int, int],
    tile: Size,
) -> List[Image.Image]:
    """
    Crop the same tile-sized region out of every image in a stage result.

    Args:
        result: the stack of images to crop, read through `as_image()`.
        origin: (left, top) corner of the region, in pixels.
        tile: size of the region; `width` extends to the right of the origin
            and `height` extends below it.

    Returns:
        One cropped image per layer, in the same order as the stack.
    """
    # the origin is (left, top), matching the order of the crop box
    left, top = origin
    box = (left, top, left + tile.width, top + tile.height)

    return [layer.crop(box) for layer in result.as_image()]
def add_margin(
    stack: List[Image.Image],
    left: int,
    top: int,
    right: int,
    bottom: int,
    left_margin: int,
    top_margin: int,
    right_margin: int,
    bottom_margin: int,
    tile: int,
    noise_source,
    fill_color,
) -> List[Image.Image]:
    """
    Build one full-size tile per image in the stack, padding with noise.

    For each image, the region inside the margins is cropped out, a
    `tile` x `tile` background is produced by `noise_source`, and the cropped
    region is pasted onto that background at `(left_margin, top_margin)`.

    Args:
        stack: the images to tile. An object exposing `as_image()` is also
            accepted and is unwrapped first.
        left, top, right, bottom: bounds of the tile in source coordinates.
        left_margin, top_margin, right_margin, bottom_margin: offsets added to
            the matching bound before cropping.
        tile: edge length of the output tile, in pixels.
        noise_source: callable invoked as
            `noise_source(image, (tile, tile), (0, 0), fill=fill_color)`.
        fill_color: forwarded to `noise_source` as `fill`.

    Returns:
        The padded tiles, one per image, in the same order as the stack.
    """
    # the tiling code hands over the whole stage result rather than a list,
    # so unwrap it when possible and fall back to plain iteration otherwise
    layers = stack.as_image() if hasattr(stack, "as_image") else stack
    crop_box = (
        left + left_margin,
        top + top_margin,
        right + right_margin,
        bottom + bottom_margin,
    )

    results = []
    for source in layers:
        base_image = source.crop(crop_box)
        tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
        tile_image.paste(base_image, (left_margin, top_margin))
        # collect the padded tile, otherwise the caller gets an empty stack
        results.append(tile_image)

    return results

View File

@ -79,12 +79,15 @@ class UpscaleBSRGANStage(BaseStage):
logger.trace("BSRGAN input shape: %s", image.shape)
scale = upscale.outscale
logger.trace("BSRGAN output shape: %s", (
logger.trace(
"BSRGAN output shape: %s",
(
image.shape[0],
image.shape[1],
image.shape[2] * scale,
image.shape[3] * scale,
))
),
)
output = bsrgan(image)

View File

@ -38,10 +38,14 @@ class UpscaleSimpleStage(BaseStage):
if method == "bilinear":
logger.debug("using bilinear interpolation for highres")
outputs.append(source.resize(scaled_size, resample=Image.Resampling.BILINEAR))
outputs.append(
source.resize(scaled_size, resample=Image.Resampling.BILINEAR)
)
elif method == "lanczos":
logger.debug("using Lanczos interpolation for highres")
outputs.append(source.resize(scaled_size, resample=Image.Resampling.LANCZOS))
outputs.append(
source.resize(scaled_size, resample=Image.Resampling.LANCZOS)
)
else:
logger.warning("unknown upscaling method: %s", method)

View File

@ -105,7 +105,9 @@ def run_txt2img_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback()
images = chain.run(worker, server, params, StageResult.empty(), callback=progress, latents=latents)
images = chain.run(
worker, server, params, StageResult.empty(), callback=progress, latents=latents
)
_pairs, loras, inversions, _rest = parse_prompt(params)
@ -200,7 +202,9 @@ def run_img2img_pipeline(
# run and append the filtered source
progress = worker.get_progress_callback()
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
images = chain.run(
worker, server, params, StageResult(images=[source]), callback=progress
)
if source_filter is not None and source_filter != "none":
images.append(source)
@ -380,7 +384,9 @@ def run_inpaint_pipeline(
# run and save
latents = get_latents_from_seed(params.seed, size, batch=params.batch)
progress = worker.get_progress_callback()
images = chain.run(worker, server, params, [source], callback=progress, latents=latents)
images = chain.run(
worker, server, params, [source], callback=progress, latents=latents
)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
@ -455,7 +461,9 @@ def run_upscale_pipeline(
# run and save
progress = worker.get_progress_callback()
images = chain.run(worker, server, params, StageResult(images=[source]), callback=progress)
images = chain.run(
worker, server, params, StageResult(images=[source]), callback=progress
)
_pairs, loras, inversions, _rest = parse_prompt(params)
for image, output in zip(images, outputs):
@ -505,7 +513,9 @@ def run_blend_pipeline(
# run and save
progress = worker.get_progress_callback()
images = chain.run(worker, server, params, StageResult(images=sources), callback=progress)
images = chain.run(
worker, server, params, StageResult(images=sources), callback=progress
)
for image, output in zip(images, outputs):
dest = save_image(server, output, image, params, size, upscale=upscale)

View File

@ -1,6 +1,6 @@
import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.result import StageResult
from onnx_web.chain.source_noise import SourceNoiseStage
from onnx_web.image.noise_source import noise_source_fill_edge
from onnx_web.params import HighresParams, Size, UpscaleParams

View File

@ -1,6 +1,6 @@
import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.result import StageResult
from onnx_web.chain.source_s3 import SourceS3Stage
from onnx_web.params import HighresParams, Size, UpscaleParams

View File

@ -1,6 +1,6 @@
import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.result import StageResult
from onnx_web.chain.source_url import SourceURLStage
from onnx_web.params import HighresParams, Size, UpscaleParams

View File

@ -4,11 +4,11 @@ from PIL import Image
from onnx_web.chain.tile import (
complete_tile,
generate_tile_grid,
generate_tile_spiral,
get_tile_grads,
make_tile_grads,
needs_tile,
process_tile_grid,
process_tile_spiral,
process_tile_stack,
)
from onnx_web.params import Size
@ -59,24 +59,46 @@ class TestNeedsTile(unittest.TestCase):
class TestTileGrads(unittest.TestCase):
def test_center_tile(self):
grad_x, grad_y = get_tile_grads(32, 32, 8, 64, 64)
grad_x, grad_y = make_tile_grads(32, 32, 8, 64, 64)
self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [0, 1, 1, 0])
def test_vertical_edge_tile(self):
grad_x, grad_y = get_tile_grads(32, 0, 8, 64, 8)
grad_x, grad_y = make_tile_grads(32, 0, 8, 64, 8)
self.assertEqual(grad_x, [0, 1, 1, 0])
self.assertEqual(grad_y, [1, 1, 1, 1])
def test_horizontal_edge_tile(self):
grad_x, grad_y = get_tile_grads(0, 32, 8, 8, 64)
grad_x, grad_y = make_tile_grads(0, 32, 8, 8, 64)
self.assertEqual(grad_x, [1, 1, 1, 1])
self.assertEqual(grad_y, [0, 1, 1, 0])
class TestGenerateTileGrid(unittest.TestCase):
    """
    Tests for generate_tile_grid, which emits tile origins in row-major order:
    every column of the first row, then every column of the next row.
    """

    def test_grid_complete(self):
        tiles = generate_tile_grid(16, 16, 8, 0.0)

        self.assertEqual(len(tiles), 4)
        # row-major, not the clockwise order used by the spiral generator
        self.assertEqual(tiles, [(0, 0), (8, 0), (0, 8), (8, 8)])

    def test_grid_no_overlap(self):
        tiles = generate_tile_grid(64, 64, 8, 0.0)

        # 8 columns by 8 rows
        self.assertEqual(len(tiles), 64)
        self.assertEqual(tiles[0:4], [(0, 0), (8, 0), (16, 0), (24, 0)])
        # the tail of the list is the end of the last row, where top is 56
        self.assertEqual(tiles[-5:-1], [(24, 56), (32, 56), (40, 56), (48, 56)])

    def test_grid_50_overlap(self):
        tiles = generate_tile_grid(64, 64, 8, 0.5)

        # a stride of 4 pixels gives 16 columns by 16 rows
        self.assertEqual(len(tiles), 256)
        self.assertEqual(tiles[0:4], [(0, 0), (4, 0), (8, 0), (12, 0)])
        # the tail of the list is the end of the last row, where top is 60
        self.assertEqual(tiles[-5:-1], [(44, 60), (48, 60), (52, 60), (56, 60)])
class TestGenerateTileSpiral(unittest.TestCase):
def test_spiral_complete(self):
tiles = generate_tile_spiral(16, 16, 8, 0.0)
@ -99,29 +121,15 @@ class TestGenerateTileSpiral(unittest.TestCase):
self.assertEqual(tiles[-5:-1], [(32, 32), (28, 32), (24, 32), (24, 28)])
class TestProcessTileGrid(unittest.TestCase):
class TestProcessTileStack(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_grid(source, 32, 1, [])
blend = process_tile_stack(source, 32, 1, [])
self.assertEqual(blend.size, (64, 64))
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_grid(source, 32, 1, [])
self.assertEqual(blend.size, (72, 72))
class TestProcessTileSpiral(unittest.TestCase):
def test_grid_full(self):
source = Image.new("RGB", (64, 64))
blend = process_tile_spiral(source, 32, 1, [])
self.assertEqual(blend.size, (64, 64))
def test_grid_partial(self):
source = Image.new("RGB", (72, 72))
blend = process_tile_spiral(source, 32, 1, [])
blend = process_tile_stack(source, 32, 1, [])
self.assertEqual(blend.size, (72, 72))

View File

@ -1,6 +1,6 @@
import unittest
from onnx_web.chain.result import StageResult
from onnx_web.chain.result import StageResult
from onnx_web.chain.upscale_highres import UpscaleHighresStage
from onnx_web.params import HighresParams, UpscaleParams