559 lines
16 KiB
Python
559 lines
16 KiB
Python
import itertools
|
|
from enum import Enum
|
|
from logging import getLogger
|
|
from math import ceil
|
|
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from ..image.noise_source import noise_source_histogram
|
|
from ..params import Size, TileOrder
|
|
from .result import StageResult
|
|
|
|
# from skimage.exposure import match_histograms
|
|
|
|
|
|
logger = getLogger(__name__)
|
|
|
|
|
|
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]
|
|
|
|
|
|
class TileCallback(Protocol):
|
|
"""
|
|
Definition for a tile job function.
|
|
"""
|
|
|
|
def __call__(
|
|
self,
|
|
sources: List[Image.Image],
|
|
mask: Image.Image,
|
|
dims: Tuple[int, int, int],
|
|
progress: Tuple[int, int],
|
|
) -> StageResult:
|
|
"""
|
|
Run this stage against a single tile.
|
|
"""
|
|
pass
|
|
|
|
|
|
def complete_tile(
|
|
source: Image.Image,
|
|
tile: int,
|
|
) -> Image.Image:
|
|
"""
|
|
TODO: clean up
|
|
"""
|
|
if source is None:
|
|
return source
|
|
|
|
if source.width < tile or source.height < tile:
|
|
full_source = Image.new(source.mode, (tile, tile))
|
|
full_source.paste(source)
|
|
return full_source
|
|
|
|
return source
|
|
|
|
|
|
def needs_tile(
|
|
max_tile: int,
|
|
stage_tile: int,
|
|
size: Optional[Size] = None,
|
|
source: Optional[Image.Image] = None,
|
|
) -> bool:
|
|
tile = min(max_tile, stage_tile)
|
|
logger.trace(
|
|
"checking image tile dimensions: %s, %s, %s",
|
|
tile,
|
|
source.width > tile or source.height > tile if source is not None else False,
|
|
size.width > tile or size.height > tile if size is not None else False,
|
|
)
|
|
|
|
if source is not None:
|
|
return source.width > tile or source.height > tile
|
|
|
|
if size is not None:
|
|
return size.width > tile or size.height > tile
|
|
|
|
return False
|
|
|
|
|
|
def make_tile_grads(
|
|
left: int,
|
|
top: int,
|
|
tile: int,
|
|
width: int,
|
|
height: int,
|
|
) -> Tuple[Tuple[float, float, float, float], Tuple[float, float, float, float]]:
|
|
grad_x = [0, 1, 1, 0]
|
|
grad_y = [0, 1, 1, 0]
|
|
|
|
if left <= 0:
|
|
grad_x[0] = 1
|
|
|
|
if top <= 0:
|
|
grad_y[0] = 1
|
|
|
|
if (left + tile) >= width:
|
|
grad_x[3] = 1
|
|
|
|
if (top + tile) >= height:
|
|
grad_y[3] = 1
|
|
|
|
return (grad_x, grad_y)
|
|
|
|
|
|
def make_tile_mask(
|
|
shape: Any,
|
|
tile: Tuple[int, int],
|
|
overlap: float,
|
|
edges: Tuple[bool, bool, bool, bool],
|
|
) -> np.ndarray:
|
|
mask = np.ones(shape)
|
|
|
|
tile_h, tile_w = tile
|
|
|
|
adj_tile_h = int(float(tile_h) * (1.0 - overlap))
|
|
adj_tile_w = int(float(tile_w) * (1.0 - overlap))
|
|
|
|
# sort gradient points
|
|
p1_h = adj_tile_h - 1
|
|
p2_h = tile_h - adj_tile_h
|
|
points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h]
|
|
|
|
p1_w = adj_tile_w - 1
|
|
p2_w = tile_w - adj_tile_w
|
|
points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w]
|
|
|
|
# build gradients
|
|
edge_t, edge_l, edge_b, edge_r = edges
|
|
grad_x, grad_y = [int(not edge_l), 1, 1, int(not edge_r)], [
|
|
int(not edge_t),
|
|
1,
|
|
1,
|
|
int(not edge_b),
|
|
]
|
|
logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)
|
|
|
|
mult_x = [np.interp(i, points_w, grad_x) for i in range(tile_w)]
|
|
mult_y = [np.interp(i, points_h, grad_y) for i in range(tile_h)]
|
|
|
|
mask = ((mask * mult_x).T * mult_y).T
|
|
|
|
return mask
|
|
|
|
|
|
def get_channels(image: Union[np.ndarray, Image.Image]) -> int:
|
|
if isinstance(image, np.ndarray):
|
|
return image.shape[-1]
|
|
|
|
if image.mode == "RGBA":
|
|
return 4
|
|
elif image.mode == "RGB":
|
|
return 3
|
|
elif image.mode == "L":
|
|
return 1
|
|
|
|
raise ValueError("unknown image format")
|
|
|
|
|
|
def blend_tiles(
|
|
tiles: List[Tuple[int, int, Image.Image]],
|
|
scale: int,
|
|
width: int,
|
|
height: int,
|
|
tile: int,
|
|
overlap: float,
|
|
):
|
|
adj_tile = int(float(tile) * (1.0 - overlap))
|
|
logger.debug(
|
|
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
|
|
)
|
|
|
|
channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles])
|
|
scaled_size = (height * scale, width * scale, channels)
|
|
|
|
count = np.zeros(scaled_size)
|
|
value = np.zeros(scaled_size)
|
|
|
|
for left, top, tile_image in tiles:
|
|
equalized = np.array(tile_image).astype(np.float32)
|
|
mask = np.ones_like(equalized[:, :, 0])
|
|
|
|
if adj_tile < tile:
|
|
# sort gradient points
|
|
p1 = (adj_tile * scale) - 1
|
|
p2 = (tile - adj_tile - 1) * scale
|
|
points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]
|
|
|
|
# gradient blending
|
|
grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
|
|
logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)
|
|
|
|
mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)]
|
|
mult_y = [np.interp(i, points, grad_y) for i in range(tile * scale)]
|
|
|
|
mask = ((mask * mult_x).T * mult_y).T
|
|
for c in range(3):
|
|
equalized[:, :, c] = equalized[:, :, c] * mask
|
|
|
|
scaled_top = top * scale
|
|
scaled_left = left * scale
|
|
|
|
# equalized size may be wrong/too much
|
|
scaled_bottom = scaled_top + equalized.shape[0]
|
|
scaled_right = scaled_left + equalized.shape[1]
|
|
|
|
writable_top = max(scaled_top, 0)
|
|
writable_left = max(scaled_left, 0)
|
|
writable_bottom = min(scaled_bottom, scaled_size[0])
|
|
writable_right = min(scaled_right, scaled_size[1])
|
|
|
|
margin_top = writable_top - scaled_top
|
|
margin_left = writable_left - scaled_left
|
|
margin_bottom = writable_bottom - scaled_bottom
|
|
margin_right = writable_right - scaled_right
|
|
|
|
logger.debug(
|
|
"tile broadcast shapes: %s, %s, %s, %s \n writing shapes: %s, %s, %s, %s",
|
|
writable_top,
|
|
writable_left,
|
|
writable_bottom,
|
|
writable_right,
|
|
margin_top,
|
|
equalized.shape[0] + margin_bottom,
|
|
margin_left,
|
|
equalized.shape[0] + margin_right,
|
|
)
|
|
|
|
# accumulation
|
|
value[
|
|
writable_top:writable_bottom, writable_left:writable_right, :
|
|
] += equalized[
|
|
margin_top : equalized.shape[0] + margin_bottom,
|
|
margin_left : equalized.shape[1] + margin_right,
|
|
:,
|
|
]
|
|
count[
|
|
writable_top:writable_bottom, writable_left:writable_right, :
|
|
] += np.repeat(
|
|
mask[
|
|
margin_top : equalized.shape[0] + margin_bottom,
|
|
margin_left : equalized.shape[1] + margin_right,
|
|
np.newaxis,
|
|
],
|
|
channels,
|
|
axis=2,
|
|
)
|
|
|
|
logger.trace("mean tiles contributing to each pixel: %s", np.mean(count))
|
|
pixels = np.where(count > 0, value / count, value)
|
|
return Image.fromarray(np.uint8(pixels))
|
|
|
|
|
|
def process_tile_stack(
|
|
stack: StageResult,
|
|
tile: int,
|
|
scale: int,
|
|
filters: List[TileCallback],
|
|
tile_generator: TileGenerator,
|
|
overlap: float = 0.5,
|
|
**kwargs,
|
|
) -> List[Image.Image]:
|
|
sources = stack.as_images()
|
|
|
|
width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)
|
|
mask = kwargs.get("mask", kwargs.get("stage_mask", None))
|
|
noise_source = kwargs.get("noise_source", noise_source_histogram)
|
|
fill_color = kwargs.get("fill_color", None)
|
|
if not mask:
|
|
tile_mask = None
|
|
|
|
tiles: List[Tuple[int, int, Image.Image]] = []
|
|
tile_coords = tile_generator(width, height, tile, overlap)
|
|
|
|
total_tiles = len(tile_coords)
|
|
single_tile = total_tiles == 1
|
|
|
|
for counter, (left, top) in enumerate(tile_coords):
|
|
logger.info("processing tile %s of %s, %sx%s", counter, total_tiles, left, top)
|
|
|
|
right = left + tile
|
|
bottom = top + tile
|
|
|
|
left_margin = right_margin = top_margin = bottom_margin = 0
|
|
needs_margin = False
|
|
|
|
if left < 0:
|
|
needs_margin = True
|
|
left_margin = 0 - left
|
|
if right > width:
|
|
needs_margin = True
|
|
right_margin = width - right
|
|
if top < 0:
|
|
needs_margin = True
|
|
top_margin = 0 - top
|
|
if bottom > height:
|
|
needs_margin = True
|
|
bottom_margin = height - bottom
|
|
|
|
if single_tile:
|
|
logger.debug("using single tile")
|
|
tile_stack = sources
|
|
if mask:
|
|
tile_mask = mask
|
|
elif needs_margin:
|
|
logger.debug(
|
|
"tiling with added margins: %s, %s, %s, %s",
|
|
left_margin,
|
|
top_margin,
|
|
right_margin,
|
|
bottom_margin,
|
|
)
|
|
tile_stack = add_margin(
|
|
stack.as_images(),
|
|
left,
|
|
top,
|
|
right,
|
|
bottom,
|
|
left_margin,
|
|
top_margin,
|
|
right_margin,
|
|
bottom_margin,
|
|
tile,
|
|
noise_source,
|
|
fill_color,
|
|
)
|
|
|
|
if mask:
|
|
base_mask = mask.crop(
|
|
(
|
|
left + left_margin,
|
|
top + top_margin,
|
|
right + right_margin,
|
|
bottom + bottom_margin,
|
|
)
|
|
)
|
|
tile_mask = Image.new("L", (tile, tile), color=0)
|
|
tile_mask.paste(base_mask, (left_margin, top_margin))
|
|
|
|
else:
|
|
logger.debug("tiling normally")
|
|
tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
|
|
if mask:
|
|
tile_mask = mask.crop((left, top, right, bottom))
|
|
|
|
for image_filter in filters:
|
|
tile_stack = image_filter(
|
|
tile_stack, tile_mask, (left, top, tile), (counter, total_tiles)
|
|
)
|
|
|
|
# TODO: this should be inverted to extract them from the result
|
|
if isinstance(tile_stack, list):
|
|
tile_stack = StageResult.from_images(tile_stack, metadata=stack.metadata)
|
|
|
|
tiles.append((left, top, tile_stack.as_images()))
|
|
|
|
lefts, tops, stacks = list(zip(*tiles))
|
|
coords = list(zip(lefts, tops))
|
|
stacks = list(zip(*stacks))
|
|
|
|
result = []
|
|
for stack in stacks:
|
|
stack_tiles = zip(coords, stack)
|
|
stack_tiles = [(left, top, tile) for (left, top), tile in stack_tiles]
|
|
result.append(blend_tiles(stack_tiles, scale, width, height, tile, overlap))
|
|
|
|
return result
|
|
|
|
|
|
def process_tile_order(
|
|
order: TileOrder,
|
|
stack: StageResult,
|
|
tile: int,
|
|
scale: int,
|
|
filters: List[TileCallback],
|
|
**kwargs,
|
|
) -> List[Image.Image]:
|
|
if order == TileOrder.grid:
|
|
logger.debug("using grid tile order with tile size: %s", tile)
|
|
return process_tile_stack(
|
|
stack, tile, scale, filters, generate_tile_grid, **kwargs
|
|
)
|
|
elif order == TileOrder.kernel:
|
|
logger.debug("using kernel tile order with tile size: %s", tile)
|
|
raise NotImplementedError()
|
|
elif order == TileOrder.spiral:
|
|
logger.debug("using spiral tile order with tile size: %s", tile)
|
|
return process_tile_stack(
|
|
stack, tile, scale, filters, generate_tile_spiral, **kwargs
|
|
)
|
|
else:
|
|
logger.warning("unknown tile order: %s", order)
|
|
raise ValueError()
|
|
|
|
|
|
def generate_tile_spiral(
|
|
width: int,
|
|
height: int,
|
|
tile: int,
|
|
overlap: float = 0.0,
|
|
) -> List[Tuple[int, int]]:
|
|
spacing = 1.0 - overlap
|
|
|
|
tile_increment = (
|
|
round(tile * spacing / 2) * 2
|
|
) # dividing and then multiplying by 2 ensures this will be an even number, which is necessary for the initial tile placement calculation
|
|
|
|
# calculate the number of tiles needed
|
|
if width > tile:
|
|
width_tile_target = 1 + ceil((width - tile) / tile_increment)
|
|
else:
|
|
width_tile_target = 1
|
|
if height > tile:
|
|
height_tile_target = 1 + ceil((height - tile) / tile_increment)
|
|
else:
|
|
height_tile_target = 1
|
|
|
|
# calculate the start position of the tiling
|
|
span_x = tile + (width_tile_target - 1) * tile_increment
|
|
span_y = tile + (height_tile_target - 1) * tile_increment
|
|
|
|
logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y)
|
|
|
|
tile_left = (
|
|
width - span_x
|
|
) / 2 # guaranteed to be an integer because width and span will both be even
|
|
tile_top = (
|
|
height - span_y
|
|
) / 2 # guaranteed to be an integer because width and span will both be even
|
|
|
|
logger.debug(
|
|
"image size %s x %s, tiling to %s x %s, starting at %s, %s",
|
|
width,
|
|
height,
|
|
width_tile_target,
|
|
height_tile_target,
|
|
tile_left,
|
|
tile_top,
|
|
)
|
|
|
|
tile_coords = []
|
|
|
|
# start walking from the north-west corner, heading east
|
|
class WalkState(Enum):
|
|
EAST = (1, 0)
|
|
SOUTH = (0, 1)
|
|
WEST = (-1, 0)
|
|
NORTH = (0, -1)
|
|
|
|
# initialize the tile_left placement
|
|
tile_left -= tile_increment
|
|
height_tile_target -= 1
|
|
|
|
for state in itertools.cycle(WalkState):
|
|
# This expression is stupid, but all it does is calculate the number of tiles we need in the appropriate direction
|
|
accum_tile_target = max(
|
|
map(
|
|
lambda coord, val: abs(coord * val),
|
|
state.value,
|
|
(width_tile_target, height_tile_target),
|
|
)
|
|
)
|
|
# check if done
|
|
if accum_tile_target == 0:
|
|
break
|
|
|
|
# reset tile count
|
|
accum_tiles = 0
|
|
while accum_tiles < accum_tile_target:
|
|
# move to the next
|
|
tile_left += tile_increment * state.value[0]
|
|
tile_top += tile_increment * state.value[1]
|
|
|
|
# add a tile
|
|
logger.debug("adding tile at %s:%s", tile_left, tile_top)
|
|
tile_coords.append((int(tile_left), int(tile_top)))
|
|
|
|
accum_tiles += 1
|
|
|
|
width_tile_target -= abs(state.value[0])
|
|
height_tile_target -= abs(state.value[1])
|
|
|
|
return tile_coords
|
|
|
|
|
|
def generate_tile_grid(
|
|
width: int,
|
|
height: int,
|
|
tile: int,
|
|
overlap: float = 0.0,
|
|
) -> List[Tuple[int, int, Image.Image]]:
|
|
adj_tile = int(float(tile) * (1.0 - overlap))
|
|
tiles_x = ceil(width / adj_tile)
|
|
tiles_y = ceil(height / adj_tile)
|
|
total = tiles_x * tiles_y
|
|
logger.debug(
|
|
"processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
|
|
total,
|
|
tiles_x,
|
|
tiles_y,
|
|
adj_tile,
|
|
overlap,
|
|
)
|
|
|
|
tiles: List[Tuple[int, int, Image.Image]] = []
|
|
|
|
for y in range(tiles_y):
|
|
for x in range(tiles_x):
|
|
left = x * adj_tile
|
|
top = y * adj_tile
|
|
|
|
tiles.append((int(left), int(top)))
|
|
|
|
return tiles
|
|
|
|
|
|
def get_result_tile(
|
|
result: StageResult,
|
|
origin: Tuple[int, int],
|
|
tile: Size,
|
|
) -> List[Image.Image]:
|
|
top, left = origin
|
|
return [
|
|
layer.crop((top, left, top + tile.height, left + tile.width))
|
|
for layer in result.as_images()
|
|
]
|
|
|
|
|
|
def add_margin(
|
|
stack: List[Image.Image],
|
|
left: int,
|
|
top: int,
|
|
right: int,
|
|
bottom: int,
|
|
left_margin: int,
|
|
top_margin: int,
|
|
right_margin: int,
|
|
bottom_margin: int,
|
|
tile: int,
|
|
noise_source,
|
|
fill_color,
|
|
) -> List[Image.Image]:
|
|
results = []
|
|
for source in stack:
|
|
base_image = source.crop(
|
|
(
|
|
left + left_margin,
|
|
top + top_margin,
|
|
right + right_margin,
|
|
bottom + bottom_margin,
|
|
)
|
|
)
|
|
tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
|
|
tile_image.paste(base_image, (left_margin, top_margin))
|
|
results.append(tile_image)
|
|
|
|
return results
|