1
0
Fork 0
onnx-web/api/onnx_web/chain/tile.py

552 lines
15 KiB
Python
Raw Normal View History

2023-07-08 14:03:06 +00:00
import itertools
2023-07-08 14:17:35 +00:00
from enum import Enum
2023-01-28 23:09:19 +00:00
from logging import getLogger
2023-04-04 02:39:10 +00:00
from math import ceil
2023-11-26 03:19:11 +00:00
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union
import numpy as np
2023-02-05 13:53:26 +00:00
from PIL import Image
from ..image.noise_source import noise_source_histogram
2023-07-08 14:03:06 +00:00
from ..params import Size, TileOrder
from .result import StageResult
2023-06-04 02:00:59 +00:00
# from skimage.exposure import match_histograms
2023-01-28 23:09:19 +00:00
logger = getLogger(__name__)

# Signature shared by the tile coordinate generators in this module
# (generate_tile_grid, generate_tile_spiral): called as
# (width, height, tile, overlap) and returning the (left, top) origin of each
# tile, in the order the tiles should be processed.
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
class TileCallback(Protocol):
    """
    Callable protocol for a job that runs a pipeline stage against one tile.
    """

    def __call__(
        self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]
    ) -> StageResult:
        """
        Process a single tile.

        As invoked by process_tile_stack, `sources` holds the tile's images, `mask`
        is the matching tile mask (None when no mask was supplied) and `dims` is
        (left, top, tile size) of the tile within the full image.
        """
        ...
def complete_tile(
    source: Image.Image,
    tile: int,
) -> Image.Image:
    """
    Pad an image that is narrower or shorter than `tile` onto a new
    `tile` x `tile` canvas of the same mode, anchored at the top-left corner.

    `None`, and images that are already at least `tile` wide and tall, are
    returned unchanged. When only one dimension is short, the canvas is still
    exactly `tile` x `tile`, so content beyond `tile` in the other dimension
    is clipped by the paste.
    """
    if source is None or (source.width >= tile and source.height >= tile):
        return source

    padded = Image.new(source.mode, (tile, tile))
    padded.paste(source)
    return padded
def needs_tile(
    max_tile: int,
    stage_tile: int,
    size: Optional[Size] = None,
    source: Optional[Image.Image] = None,
) -> bool:
    """
    Decide whether an image is too large to process as a single tile.

    The effective tile size is the smaller of `max_tile` and `stage_tile`. When
    `source` is given, its dimensions decide the result and `size` is ignored;
    otherwise `size` is checked; with neither, no tiling is needed.
    """
    limit = min(max_tile, stage_tile)

    source_exceeds = source is not None and (
        source.width > limit or source.height > limit
    )
    size_exceeds = size is not None and (size.width > limit or size.height > limit)

    logger.trace(
        "checking image tile dimensions: %s, %s, %s",
        limit,
        source_exceeds,
        size_exceeds,
    )

    if source is not None:
        return source_exceeds
    return size_exceeds
2023-11-20 00:39:39 +00:00
def make_tile_grads(
    left: int,
    top: int,
    tile: int,
    width: int,
    height: int,
) -> Tuple[List[int], List[int]]:
    """
    Build the horizontal and vertical blending weights for one tile.

    Each gradient holds four weights. blend_tiles interpolates them over four
    sorted points along the tile. By default the weights are [0, 1, 1, 0]:
    they ramp up from the start of the tile and back down at its end, so
    overlapping tiles cross-fade. On a side where the tile touches the image
    border, the outer weight is pinned to 1.

    Args:
        left, top: tile origin.
        tile: tile extent used for the right/bottom border test.
        width, height: image size.

    Returns:
        (grad_x, grad_y), each a list of four ints.
    """
    grad_x = [0, 1, 1, 0]
    grad_y = [0, 1, 1, 0]

    # first column/row of the image: do not fade in from the left/top
    if left <= 0:
        grad_x[0] = 1
    if top <= 0:
        grad_y[0] = 1

    # tile reaches the right/bottom of the image: do not fade out there
    if left + tile >= width:
        grad_x[3] = 1
    if top + tile >= height:
        grad_y[3] = 1

    return (grad_x, grad_y)
def make_tile_mask(
    shape: Any,
    tile: Tuple[int, int],
    overlap: float,
    edges: Tuple[bool, bool, bool, bool],
) -> np.ndarray:
    """
    Build a blending weight mask for a single tile.

    Args:
        shape: shape of the base `np.ones` mask. The last axis is multiplied by
            the horizontal weights (length tile_w) and the first axis by the
            vertical weights (length tile_h), so presumably this is
            (tile_h, tile_w). TODO confirm with callers.
        tile: (height, width) of the tile.
        overlap: overlap fraction; the adjusted (non-overlapping) size is
            tile * (1 - overlap), truncated to an int.
        edges: (top, left, bottom, right) flags. A True flag sets the outermost
            weight on that side to 0; a False flag sets it to 1.

    Returns:
        The mask, with weights linearly interpolated between the gradient points.
    """
    mask = np.ones(shape)

    tile_h, tile_w = tile
    adj_tile_h = int(float(tile_h) * (1.0 - overlap))
    adj_tile_w = int(float(tile_w) * (1.0 - overlap))

    # sort the inner gradient points so np.interp receives increasing x values
    p1_h = adj_tile_h - 1
    p2_h = tile_h - adj_tile_h
    points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h]

    p1_w = adj_tile_w - 1
    p2_w = tile_w - adj_tile_w
    points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w]

    # build gradients
    edge_t, edge_l, edge_b, edge_r = edges
    grad_x = [int(not edge_l), 1, 1, int(not edge_r)]
    grad_y = [int(not edge_t), 1, 1, int(not edge_b)]

    logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)

    # evaluate the piecewise-linear weights for every column/row in one call,
    # instead of one np.interp call per pixel
    mult_x = np.interp(np.arange(tile_w), points_w, grad_x)
    mult_y = np.interp(np.arange(tile_h), points_h, grad_y)

    # scale columns by mult_x, then (via the transposes) rows by mult_y
    mask = ((mask * mult_x).T * mult_y).T
    return mask
2023-11-26 03:19:11 +00:00
def get_channels(image: Union[np.ndarray, Image.Image]) -> int:
    """
    Return the number of color channels in an image.

    For numpy arrays, a 2D array has no channel axis and is treated as a single
    channel, matching the result for PIL "L" images. Otherwise the size of the
    last axis is used. PIL images are supported in the "RGBA", "RGB" and "L"
    modes.

    Raises:
        ValueError: for any other PIL image mode.
    """
    if isinstance(image, np.ndarray):
        # shape[-1] of a 2D (height, width) array would be the width, not 1
        if image.ndim == 2:
            return 1
        return image.shape[-1]

    mode_channels = {"RGBA": 4, "RGB": 3, "L": 1}
    if image.mode in mode_channels:
        return mode_channels[image.mode]

    raise ValueError("unknown image format")
def blend_tiles(
    tiles: List[Tuple[int, int, Image.Image]],
    scale: int,
    width: int,
    height: int,
    tile: int,
    overlap: float,
):
    """
    Blend overlapping tiles back into a single image.

    Each tile image is weighted by a gradient mask when the tiles overlap. The
    weighted pixels are accumulated onto a float canvas of shape
    (height * scale, width * scale, channels), and the canvas is divided by the
    accumulated weights.

    Args:
        tiles: (left, top, image) triples; left/top are unscaled coordinates
            and are multiplied by `scale` when placing the tile image.
        scale: factor between tile coordinates and output pixels.
        width, height: unscaled size of the output image.
        tile: unscaled tile size.
        overlap: overlap fraction; tiles are spaced tile * (1 - overlap) apart.

    Returns:
        The blended image, built from a uint8 array.
    """
    adj_tile = int(float(tile) * (1.0 - overlap))
    logger.debug(
        "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
    )

    channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles])
    scaled_size = (height * scale, width * scale, channels)

    count = np.zeros(scaled_size)
    value = np.zeros(scaled_size)

    for left, top, tile_image in tiles:
        equalized = np.array(tile_image).astype(np.float32)
        mask = np.ones_like(equalized[:, :, 0])

        if adj_tile < tile:
            # sort gradient points
            p1 = (adj_tile * scale) - 1
            p2 = (tile - adj_tile - 1) * scale
            points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]

            # gradient blending
            grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
            logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
            mult_x = np.interp(np.arange(tile * scale), points, grad_x)
            mult_y = np.interp(np.arange(tile * scale), points, grad_y)

            mask = ((mask * mult_x).T * mult_y).T

        # weight EVERY channel (including alpha) by the mask: `count` below is
        # accumulated from the mask for all channels, so an unweighted channel
        # would be skewed by the final division
        equalized *= mask[:, :, np.newaxis]

        scaled_top = top * scale
        scaled_left = left * scale

        # the tile image may extend past the canvas, so clip the destination to
        # the canvas and trim the source by the same margins
        scaled_bottom = scaled_top + equalized.shape[0]
        scaled_right = scaled_left + equalized.shape[1]

        writable_top = max(scaled_top, 0)
        writable_left = max(scaled_left, 0)
        writable_bottom = min(scaled_bottom, scaled_size[0])
        writable_right = min(scaled_right, scaled_size[1])

        margin_top = writable_top - scaled_top
        margin_left = writable_left - scaled_left
        margin_bottom = writable_bottom - scaled_bottom
        margin_right = writable_right - scaled_right

        logger.debug(
            "tile broadcast shapes: %s, %s, %s, %s \n writing shapes: %s, %s, %s, %s",
            writable_top,
            writable_left,
            writable_bottom,
            writable_right,
            margin_top,
            equalized.shape[0] + margin_bottom,
            margin_left,
            equalized.shape[1] + margin_right,
        )

        src_rows = slice(margin_top, equalized.shape[0] + margin_bottom)
        src_cols = slice(margin_left, equalized.shape[1] + margin_right)

        # accumulation
        value[
            writable_top:writable_bottom, writable_left:writable_right, :
        ] += equalized[src_rows, src_cols, :]
        count[
            writable_top:writable_bottom, writable_left:writable_right, :
        ] += np.repeat(mask[src_rows, src_cols, np.newaxis], channels, axis=2)

    logger.trace("mean tiles contributing to each pixel: %s", np.mean(count))

    # normalize by the accumulated weights; pixels with no weight keep their
    # accumulated value, and `where` avoids 0/0 warnings there
    pixels = np.divide(value, count, out=np.copy(value), where=count > 0)
    # clip so the uint8 conversion cannot wrap around on float rounding
    return Image.fromarray(np.uint8(np.clip(pixels, 0, 255)))
2023-11-20 00:39:39 +00:00
def process_tile_stack(
    stack: StageResult,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    tile_generator: TileGenerator,
    overlap: float = 0.5,
    **kwargs,
) -> List[Image.Image]:
    """
    Run a chain of tile filters over every tile of a stack of images, then
    blend the processed tiles back into full images.

    Args:
        stack: input stage result; `stack.as_image()` provides the source layers.
        tile: tile size, in unscaled pixels.
        scale: passed to blend_tiles, which multiplies tile coordinates and the
            output size by it (presumably the filters' upscale factor - confirm).
        filters: callbacks applied in order to each tile as
            `filter(tile_stack, tile_mask, (left, top, tile))`.
        tile_generator: returns the (left, top) tile origins for
            (width, height, tile, overlap).
        overlap: overlap fraction between neighbouring tiles.
        **kwargs: optional `size` (defaults to the size of the first source
            image), `mask` (falling back to `stage_mask`), `noise_source`
            (defaults to noise_source_histogram) and `fill_color`; the last two
            are used to pad tiles that hang over the image border.

    Returns:
        One blended image per layer of the filtered tile stacks.
    """
    sources = stack.as_image()
    width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
    mask = kwargs.get("mask", kwargs.get("stage_mask", None))
    noise_source = kwargs.get("noise_source", noise_source_histogram)
    fill_color = kwargs.get("fill_color", None)

    # when a mask is present, every branch below assigns tile_mask instead
    if not mask:
        tile_mask = None

    tiles: List[Tuple[int, int, Image.Image]] = []
    tile_coords = tile_generator(width, height, tile, overlap)
    single_tile = len(tile_coords) == 1

    for counter, (left, top) in enumerate(tile_coords):
        logger.info(
            "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
        )

        right = left + tile
        bottom = top + tile

        # how far this tile hangs over each image edge; the right/bottom margins
        # are negative so that `right + right_margin == width` (same for bottom)
        left_margin = right_margin = top_margin = bottom_margin = 0
        needs_margin = False

        if left < 0:
            needs_margin = True
            left_margin = 0 - left
        if right > width:
            needs_margin = True
            right_margin = width - right
        if top < 0:
            needs_margin = True
            top_margin = 0 - top
        if bottom > height:
            needs_margin = True
            bottom_margin = height - bottom

        if single_tile:
            # only one tile: hand the full source images (and mask) to the filters
            logger.debug("using single tile")
            tile_stack = sources
            if mask:
                tile_mask = mask
        elif needs_margin:
            # crop the in-image part of the tile and pad it out to a full tile
            logger.debug(
                "tiling with added margins: %s, %s, %s, %s",
                left_margin,
                top_margin,
                right_margin,
                bottom_margin,
            )
            tile_stack = add_margin(
                stack.as_image(),
                left,
                top,
                right,
                bottom,
                left_margin,
                top_margin,
                right_margin,
                bottom_margin,
                tile,
                noise_source,
                fill_color,
            )
            if mask:
                base_mask = mask.crop(
                    (
                        left + left_margin,
                        top + top_margin,
                        right + right_margin,
                        bottom + bottom_margin,
                    )
                )
                # padded area of the mask is 0
                tile_mask = Image.new("L", (tile, tile), color=0)
                tile_mask.paste(base_mask, (left_margin, top_margin))

        else:
            logger.debug("tiling normally")
            tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
            if mask:
                tile_mask = mask.crop((left, top, right, bottom))

        for image_filter in filters:
            tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))

        # tile_stack is still a plain list when there are no filters or when a
        # filter returned a list; normalize so `.as_image()` is available
        if isinstance(tile_stack, list):
            tile_stack = StageResult.from_images(tile_stack)

        tiles.append((left, top, tile_stack.as_image()))

    # regroup the per-tile layer lists into per-layer lists of (left, top, image)
    lefts, tops, stacks = list(zip(*tiles))
    coords = list(zip(lefts, tops))
    stacks = list(zip(*stacks))

    result = []
    for stack in stacks:
        stack_tiles = zip(coords, stack)
        stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles]
        result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))

    return result
def process_tile_order(
    order: TileOrder,
    stack: StageResult,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    **kwargs,
) -> List[Image.Image]:
    """
    Process a stack of images tile by tile, visiting tiles in the given order.

    Args:
        order: TileOrder.grid or TileOrder.spiral. TileOrder.kernel is not
            implemented.
        stack, tile, scale, filters, **kwargs: forwarded to process_tile_stack.

    Returns:
        The blended images produced by process_tile_stack.

    Raises:
        NotImplementedError: for TileOrder.kernel.
        ValueError: for any other, unknown order.
    """
    if order == TileOrder.grid:
        logger.debug("using grid tile order with tile size: %s", tile)
        return process_tile_stack(
            stack, tile, scale, filters, generate_tile_grid, **kwargs
        )
    elif order == TileOrder.kernel:
        logger.debug("using kernel tile order with tile size: %s", tile)
        raise NotImplementedError("kernel tile order is not implemented")
    elif order == TileOrder.spiral:
        logger.debug("using spiral tile order with tile size: %s", tile)
        return process_tile_stack(
            stack, tile, scale, filters, generate_tile_spiral, **kwargs
        )
    else:
        logger.warning("unknown tile order: %s", order)
        raise ValueError(f"unknown tile order: {order}")
2023-04-04 02:39:10 +00:00
def generate_tile_spiral(
    width: int,
    height: int,
    tile: int,
    overlap: float = 0.0,
) -> List[Tuple[int, int]]:
    """
    Generate tile origins covering a width x height image, ordered as an
    inward spiral.

    The tile grid is centered on the image, so the first origins may be
    negative and tiles may extend past the right/bottom edges. The walk starts
    at the top-left tile and heads east, then south, west and north, shrinking
    the remaining row/column counts after each leg until none are left.

    Returns:
        (left, top) origin of each tile, truncated to ints.

    NOTE(review): tile_increment becomes 0 when tile * (1 - overlap) / 2 rounds
    to 0 (e.g. overlap=1.0). The ceil() divisions then raise ZeroDivisionError
    whenever the image is larger than the tile.
    """
    spacing = 1.0 - overlap

    # dividing and then multiplying by 2 ensures this will be an even number,
    # which the centering of the initial tile placement below relies on
    tile_increment = (
        round(tile * spacing / 2) * 2
    )

    # calculate the number of tiles needed
    if width > tile:
        width_tile_target = 1 + ceil((width - tile) / tile_increment)
    else:
        width_tile_target = 1

    if height > tile:
        height_tile_target = 1 + ceil((height - tile) / tile_increment)
    else:
        height_tile_target = 1

    # calculate the start position of the tiling
    span_x = tile + (width_tile_target - 1) * tile_increment
    span_y = tile + (height_tile_target - 1) * tile_increment

    logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y)

    # center the span on the image. This is a whole number only when
    # (width - span_x) is even; span_x is even for an even tile size, but width
    # may be odd, leaving a half-pixel offset that int() truncates when the
    # coordinates are appended below
    tile_left = (
        width - span_x
    ) / 2
    # same centering (and half-pixel caveat) for the vertical axis
    tile_top = (
        height - span_y
    ) / 2

    logger.debug(
        "image size %s x %s, tiling to %s x %s, starting at %s, %s",
        width,
        height,
        width_tile_target,
        height_tile_target,
        tile_left,
        tile_top,
    )

    tile_coords = []

    # start walking from the north-west corner, heading east
    # each state's value is the (dx, dy) step direction in tile increments
    class WalkState(Enum):
        EAST = (1, 0)
        SOUTH = (0, 1)
        WEST = (-1, 0)
        NORTH = (0, -1)

    # step back one increment so the first EAST move lands on the starting tile,
    # and pre-count the first row, which that EAST leg places
    tile_left -= tile_increment
    height_tile_target -= 1

    for state in itertools.cycle(WalkState):
        # number of tiles to place on this leg: the remaining column count when
        # moving east/west, the remaining row count when moving north/south
        accum_tile_target = max(
            map(
                lambda coord, val: abs(coord * val),
                state.value,
                (width_tile_target, height_tile_target),
            )
        )

        # check if done
        if accum_tile_target == 0:
            break

        # reset tile count
        accum_tiles = 0
        while accum_tiles < accum_tile_target:
            # move to the next
            tile_left += tile_increment * state.value[0]
            tile_top += tile_increment * state.value[1]

            # add a tile
            logger.debug("adding tile at %s:%s", tile_left, tile_top)
            tile_coords.append((int(tile_left), int(tile_top)))

            accum_tiles += 1

        # a finished leg consumes one column (east/west) or one row (north/south)
        width_tile_target -= abs(state.value[0])
        height_tile_target -= abs(state.value[1])

    return tile_coords
2023-11-20 00:39:39 +00:00
def generate_tile_grid(
    width: int,
    height: int,
    tile: int,
    overlap: float = 0.0,
) -> List[Tuple[int, int]]:
    """
    Generate tile origins on a regular grid covering a width x height image.

    Tiles are spaced int(tile * (1 - overlap)) pixels apart, starting at the
    top-left corner, so the last row/column of tiles may extend past the image.
    A spacing that truncates to 0 (e.g. overlap=1.0) raises ZeroDivisionError.

    Returns:
        (left, top) origin of each tile, in row-major order.
    """
    adj_tile = int(float(tile) * (1.0 - overlap))
    tiles_x = ceil(width / adj_tile)
    tiles_y = ceil(height / adj_tile)
    total = tiles_x * tiles_y
    logger.debug(
        "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
        total,
        tiles_x,
        tiles_y,
        adj_tile,
        overlap,
    )

    return [
        (int(x * adj_tile), int(y * adj_tile))
        for y in range(tiles_y)
        for x in range(tiles_x)
    ]
def get_result_tile(
    result: StageResult,
    origin: Tuple[int, int],
    tile: Size,
) -> List[Image.Image]:
    """
    Crop the same tile out of every image layer of a stage result.

    Args:
        result: stage result whose `as_image()` layers are cropped.
        origin: (left, top) corner of the tile.
        tile: tile size; the crop spans `tile.width` horizontally and
            `tile.height` vertically.

    Returns:
        One cropped image per layer.
    """
    left, top = origin
    # PIL crop boxes are (left, upper, right, lower)
    box = (left, top, left + tile.width, top + tile.height)
    return [layer.crop(box) for layer in result.as_image()]
def add_margin(
    stack: List[Image.Image],
    left: int,
    top: int,
    right: int,
    bottom: int,
    left_margin: int,
    top_margin: int,
    right_margin: int,
    bottom_margin: int,
    tile: int,
    noise_source,
    fill_color,
) -> List[Image.Image]:
    """
    Pad each layer's tile out to a full `tile` x `tile` image.

    The crop box is (left, top, right, bottom) with each edge shifted by its
    margin. Each layer's crop is pasted at (left_margin, top_margin) onto the
    canvas returned by
    `noise_source(crop, (tile, tile), (0, 0), fill=fill_color)`.

    Returns:
        One padded image per layer of `stack`, in order.
    """
    crop_box = (
        left + left_margin,
        top + top_margin,
        right + right_margin,
        bottom + bottom_margin,
    )
    paste_at = (left_margin, top_margin)

    padded = []
    for layer in stack:
        visible = layer.crop(crop_box)
        canvas = noise_source(visible, (tile, tile), (0, 0), fill=fill_color)
        canvas.paste(visible, paste_at)
        padded.append(canvas)
    return padded