import itertools
from enum import Enum
from logging import getLogger
from math import ceil
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union

import numpy as np
from PIL import Image

from ..image.noise_source import noise_source_histogram
from ..params import Size, TileOrder
from .result import StageResult

# from skimage.exposure import match_histograms

logger = getLogger(__name__)

# (width, height, tile, overlap) -> list of (left, top) tile origins
TileGenerator = Callable[[int, int, int, Optional[float]], List[Tuple[int, int]]]


class TileCallback(Protocol):
    """
    Definition for a tile job function.
    """

    def __call__(
        self, sources: List[Image.Image], mask: Image.Image, dims: Tuple[int, int, int]
    ) -> StageResult:
        """
        Run this stage against a single tile.
        """
        pass


class _WalkState(Enum):
    """
    Unit steps (dx, dy) used by the spiral tile walk, in the order they are cycled.
    """

    EAST = (1, 0)
    SOUTH = (0, 1)
    WEST = (-1, 0)
    NORTH = (0, -1)


def complete_tile(
    source: Image.Image,
    tile: int,
) -> Image.Image:
    """
    Pad an undersized image up to a full square tile.

    If either dimension of `source` is smaller than `tile`, a new `tile` x `tile`
    image with the same mode is created and `source` is pasted at its top-left
    corner. Otherwise (or when `source` is None) the input is returned unchanged.
    """
    if source is None:
        return source

    if source.width < tile or source.height < tile:
        full_source = Image.new(source.mode, (tile, tile))
        full_source.paste(source)
        return full_source

    return source


def needs_tile(
    max_tile: int,
    stage_tile: int,
    size: Optional[Size] = None,
    source: Optional[Image.Image] = None,
) -> bool:
    """
    Check whether an image is larger than the effective tile size.

    The effective tile is the smaller of `max_tile` and `stage_tile`. When a
    `source` image is given, its dimensions decide the result; otherwise `size`
    is used. With neither, tiling is not needed.
    """
    tile = min(max_tile, stage_tile)

    source_exceeds = source is not None and (
        source.width > tile or source.height > tile
    )
    size_exceeds = size is not None and (size.width > tile or size.height > tile)

    # NOTE(review): `trace` is not a standard logging method; this assumes the
    # project registers a custom TRACE level on its loggers - confirm.
    logger.trace(
        "checking image tile dimensions: %s, %s, %s",
        tile,
        source_exceeds,
        size_exceeds,
    )

    # the source image takes priority over the requested size
    if source is not None:
        return source_exceeds

    if size is not None:
        return size_exceeds

    return False


def make_tile_grads(
    left: int,
    top: int,
    tile: int,
    width: int,
    height: int,
) -> Tuple[List[int], List[int]]:
    """
    Build the horizontal and vertical gradient values for a tile.

    Each gradient has four values matching the four interpolation points used
    when blending: outer start, inner start, inner end, outer end. The outer
    values default to 0 (fade out) and are set to 1 (no fade) on any side where
    the tile touches or passes the edge of the image.
    """
    grad_x = [0, 1, 1, 0]
    grad_y = [0, 1, 1, 0]

    if left <= 0:
        grad_x[0] = 1

    if top <= 0:
        grad_y[0] = 1

    if (left + tile) >= width:
        grad_x[3] = 1

    if (top + tile) >= height:
        grad_y[3] = 1

    return (grad_x, grad_y)


def make_tile_mask(
    shape: Any,
    tile: Tuple[int, int],
    overlap: float,
    edges: Tuple[bool, bool, bool, bool],
) -> np.ndarray:
    """
    Build a blending mask that fades out towards the non-edge sides of a tile.

    `tile` is (height, width) and `edges` is (top, left, bottom, right); a side
    flagged as an edge keeps full weight instead of fading to 0.
    `shape` presumably needs to be (tile height, tile width) for the gradient
    multiplication to broadcast - TODO confirm against callers.
    """
    mask = np.ones(shape)

    tile_h, tile_w = tile
    adj_tile_h = int(float(tile_h) * (1.0 - overlap))
    adj_tile_w = int(float(tile_w) * (1.0 - overlap))

    # sort gradient points
    p1_h = adj_tile_h - 1
    p2_h = tile_h - adj_tile_h
    points_h = [-1, min(p1_h, p2_h), max(p1_h, p2_h), tile_h]

    p1_w = adj_tile_w - 1
    p2_w = tile_w - adj_tile_w
    points_w = [-1, min(p1_w, p2_w), max(p1_w, p2_w), tile_w]

    # build gradients
    edge_t, edge_l, edge_b, edge_r = edges
    grad_x = [int(not edge_l), 1, 1, int(not edge_r)]
    grad_y = [int(not edge_t), 1, 1, int(not edge_b)]
    logger.debug("tile gradients: %s, %s, %s, %s", points_w, points_h, grad_x, grad_y)

    mult_x = np.interp(np.arange(tile_w), points_w, grad_x)
    mult_y = np.interp(np.arange(tile_h), points_h, grad_y)

    # apply the x gradient along the last axis and the y gradient along the first
    mask = ((mask * mult_x).T * mult_y).T
    return mask


def get_channels(image: Union[np.ndarray, Image.Image]) -> int:
    """
    Return the number of channels in an array (its last dimension) or in an
    RGBA, RGB, or L mode image.

    Raises ValueError for any other image mode.
    """
    if isinstance(image, np.ndarray):
        return image.shape[-1]

    if image.mode == "RGBA":
        return 4
    elif image.mode == "RGB":
        return 3
    elif image.mode == "L":
        return 1

    raise ValueError("unknown image format")


def blend_tiles(
    tiles: List[Tuple[int, int, Image.Image]],
    scale: int,
    width: int,
    height: int,
    tile: int,
    overlap: float,
) -> Image.Image:
    """
    Blend a list of (left, top, image) tiles into a single image.

    The output canvas is `width * scale` by `height * scale`. Each tile is
    weighted by a gradient mask when tiles overlap, accumulated into the canvas,
    and the result is normalized by the total weight at each pixel. Parts of a
    tile that fall outside the canvas are discarded.
    """
    adj_tile = int(float(tile) * (1.0 - overlap))
    logger.debug(
        "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
    )

    channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles])
    scaled_size = (height * scale, width * scale, channels)

    count = np.zeros(scaled_size)
    value = np.zeros(scaled_size)

    for left, top, tile_image in tiles:
        equalized = np.array(tile_image).astype(np.float32)
        if equalized.ndim == 2:
            # single-channel (mode "L") tiles have no channel axis; add one so
            # the rest of the loop can treat every tile the same way
            equalized = equalized[:, :, np.newaxis]

        mask = np.ones_like(equalized[:, :, 0])

        if adj_tile < tile:
            # sort gradient points
            p1 = (adj_tile * scale) - 1
            p2 = (tile - adj_tile - 1) * scale
            points = [-1, min(p1, p2), max(p1, p2), (tile * scale)]

            # gradient blending
            grad_x, grad_y = make_tile_grads(left, top, adj_tile, width, height)
            logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)

            positions = np.arange(tile * scale)
            mult_x = np.interp(positions, points, grad_x)
            mult_y = np.interp(positions, points, grad_y)

            mask = ((mask * mult_x).T * mult_y).T

            # weight every channel, since the weights in `count` are applied to
            # every channel as well
            equalized *= mask[:, :, np.newaxis]

        scaled_top = top * scale
        scaled_left = left * scale

        # equalized size may be wrong/too much
        scaled_bottom = scaled_top + equalized.shape[0]
        scaled_right = scaled_left + equalized.shape[1]

        # clamp the destination region to the canvas
        writable_top = max(scaled_top, 0)
        writable_left = max(scaled_left, 0)
        writable_bottom = min(scaled_bottom, scaled_size[0])
        writable_right = min(scaled_right, scaled_size[1])

        # top/left margins are >= 0 and bottom/right margins are <= 0
        margin_top = writable_top - scaled_top
        margin_left = writable_left - scaled_left
        margin_bottom = writable_bottom - scaled_bottom
        margin_right = writable_right - scaled_right

        source_bottom = equalized.shape[0] + margin_bottom
        source_right = equalized.shape[1] + margin_right

        logger.debug(
            "tile broadcast shapes: %s, %s, %s, %s \n writing shapes: %s, %s, %s, %s",
            writable_top,
            writable_left,
            writable_bottom,
            writable_right,
            margin_top,
            source_bottom,
            margin_left,
            source_right,
        )

        # accumulation
        value[
            writable_top:writable_bottom, writable_left:writable_right, :
        ] += equalized[margin_top:source_bottom, margin_left:source_right, :]
        count[
            writable_top:writable_bottom, writable_left:writable_right, :
        ] += mask[margin_top:source_bottom, margin_left:source_right, np.newaxis]

    logger.trace("mean tiles contributing to each pixel: %s", np.mean(count))

    # normalize by the accumulated weight, leaving uncovered pixels untouched
    pixels = np.divide(value, count, out=value, where=count > 0)
    if channels == 1:
        # a trailing axis of length 1 cannot be converted back into an image
        pixels = pixels[:, :, 0]

    return Image.fromarray(np.uint8(pixels))


def process_tile_stack(
    stack: StageResult,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    tile_generator: TileGenerator,
    overlap: float = 0.5,
    **kwargs,
) -> List[Image.Image]:
    """
    Run a list of filters over every tile of a stack of images and blend the
    processed tiles back together, producing one output image per layer.

    Recognized keyword arguments: `size` (defaults to the size of the first
    source image), `mask` or `stage_mask`, `noise_source` (used to fill margins
    for tiles that extend past the image), and `fill_color`.
    """
    sources = stack.as_images()

    width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None)

    mask = kwargs.get("mask", kwargs.get("stage_mask", None))
    noise_source = kwargs.get("noise_source", noise_source_histogram)
    fill_color = kwargs.get("fill_color", None)

    # stays None when there is no mask, otherwise it is replaced for each tile
    tile_mask = None

    tiles: List[Tuple[int, int, List[Image.Image]]] = []
    tile_coords = tile_generator(width, height, tile, overlap)
    single_tile = len(tile_coords) == 1

    for counter, (left, top) in enumerate(tile_coords):
        logger.info(
            "processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top
        )

        right = left + tile
        bottom = top + tile

        # left/top margins are positive and right/bottom margins are negative,
        # so adding them to the tile bounds yields the part inside the image
        left_margin = -left if left < 0 else 0
        right_margin = width - right if right > width else 0
        top_margin = -top if top < 0 else 0
        bottom_margin = height - bottom if bottom > height else 0
        needs_margin = left < 0 or right > width or top < 0 or bottom > height

        if single_tile:
            logger.debug("using single tile")
            tile_stack = sources
            if mask:
                tile_mask = mask
        elif needs_margin:
            logger.debug(
                "tiling with added margins: %s, %s, %s, %s",
                left_margin,
                top_margin,
                right_margin,
                bottom_margin,
            )
            tile_stack = add_margin(
                stack.as_images(),
                left,
                top,
                right,
                bottom,
                left_margin,
                top_margin,
                right_margin,
                bottom_margin,
                tile,
                noise_source,
                fill_color,
            )

            if mask:
                base_mask = mask.crop(
                    (
                        left + left_margin,
                        top + top_margin,
                        right + right_margin,
                        bottom + bottom_margin,
                    )
                )
                # the area outside the image is left black in the mask
                tile_mask = Image.new("L", (tile, tile), color=0)
                tile_mask.paste(base_mask, (left_margin, top_margin))
        else:
            logger.debug("tiling normally")
            tile_stack = get_result_tile(stack, (left, top), Size(tile, tile))
            if mask:
                tile_mask = mask.crop((left, top, right, bottom))

        for image_filter in filters:
            tile_stack = image_filter(tile_stack, tile_mask, (left, top, tile))

        if isinstance(tile_stack, list):
            tile_stack = StageResult.from_images(tile_stack)

        tiles.append((left, top, tile_stack.as_images()))

    if not tiles:
        # nothing was generated, so there is nothing to blend
        return []

    # regroup from one list of layers per tile into one list of tiles per layer
    layers = zip(*(tile_images for _left, _top, tile_images in tiles))

    result = []
    for layer in layers:
        layer_tiles = [
            (tile_left, tile_top, tile_image)
            for (tile_left, tile_top, _images), tile_image in zip(tiles, layer)
        ]
        result.append(blend_tiles(layer_tiles, scale, width, height, tile, overlap))

    return result


def process_tile_order(
    order: TileOrder,
    stack: StageResult,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    **kwargs,
) -> List[Image.Image]:
    """
    Process a stack of images using the tile generator matching `order`.

    Raises NotImplementedError for the kernel order and ValueError for any
    unknown order.
    """
    if order == TileOrder.grid:
        logger.debug("using grid tile order with tile size: %s", tile)
        return process_tile_stack(
            stack, tile, scale, filters, generate_tile_grid, **kwargs
        )
    elif order == TileOrder.kernel:
        logger.debug("using kernel tile order with tile size: %s", tile)
        raise NotImplementedError("kernel tile order is not implemented")
    elif order == TileOrder.spiral:
        logger.debug("using spiral tile order with tile size: %s", tile)
        return process_tile_stack(
            stack, tile, scale, filters, generate_tile_spiral, **kwargs
        )
    else:
        logger.warning("unknown tile order: %s", order)
        raise ValueError(f"unknown tile order: {order}")


def generate_tile_spiral(
    width: int,
    height: int,
    tile: int,
    overlap: float = 0.0,
) -> List[Tuple[int, int]]:
    """
    Generate (left, top) tile origins walking in a spiral from the north-west
    corner inwards.

    The tiled span is centered on the image, so origins may be negative and
    tiles may extend past the image bounds.
    """
    spacing = 1.0 - overlap

    # dividing and then multiplying by 2 ensures this will be an even number
    tile_increment = round(tile * spacing / 2) * 2

    # calculate the number of tiles needed
    if width > tile:
        width_tile_target = 1 + ceil((width - tile) / tile_increment)
    else:
        width_tile_target = 1

    if height > tile:
        height_tile_target = 1 + ceil((height - tile) / tile_increment)
    else:
        height_tile_target = 1

    # calculate the start position of the tiling
    span_x = tile + (width_tile_target - 1) * tile_increment
    span_y = tile + (height_tile_target - 1) * tile_increment

    logger.debug("tiled image overlap: %s. Span: %s x %s", overlap, span_x, span_y)

    # these are floats and may be fractional when the difference is odd; the
    # coordinates are truncated with int() when each tile is emitted
    tile_left = (width - span_x) / 2
    tile_top = (height - span_y) / 2

    logger.debug(
        "image size %s x %s, tiling to %s x %s, starting at %s, %s",
        width,
        height,
        width_tile_target,
        height_tile_target,
        tile_left,
        tile_top,
    )

    tile_coords = []

    # start walking from the north-west corner, heading east; step back one
    # increment so that the first move east lands on the starting tile
    tile_left -= tile_increment
    # the first row already covers one tile of the first column
    height_tile_target -= 1

    for state in itertools.cycle(_WalkState):
        step_x, step_y = state.value

        # one of the steps is always zero, so this selects the remaining tile
        # count along the current direction of travel
        accum_tile_target = max(
            abs(step_x * width_tile_target), abs(step_y * height_tile_target)
        )

        # check if done
        if accum_tile_target == 0:
            break

        for _ in range(accum_tile_target):
            # move to the next
            tile_left += tile_increment * step_x
            tile_top += tile_increment * step_y

            # add a tile
            logger.debug("adding tile at %s:%s", tile_left, tile_top)
            tile_coords.append((int(tile_left), int(tile_top)))

        # each completed side shortens the next side along the same axis
        width_tile_target -= abs(step_x)
        height_tile_target -= abs(step_y)

    return tile_coords


def generate_tile_grid(
    width: int,
    height: int,
    tile: int,
    overlap: float = 0.0,
) -> List[Tuple[int, int]]:
    """
    Generate (left, top) tile origins for a row-major grid, spaced by the tile
    size reduced by the overlap.
    """
    adj_tile = int(float(tile) * (1.0 - overlap))
    tiles_x = ceil(width / adj_tile)
    tiles_y = ceil(height / adj_tile)
    total = tiles_x * tiles_y
    logger.debug(
        "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
        total,
        tiles_x,
        tiles_y,
        adj_tile,
        overlap,
    )

    return [
        (x * adj_tile, y * adj_tile) for y in range(tiles_y) for x in range(tiles_x)
    ]


def get_result_tile(
    result: StageResult,
    origin: Tuple[int, int],
    tile: Size,
) -> List[Image.Image]:
    """
    Crop a tile out of every image in a result.

    `origin` is the (left, top) corner of the tile and `tile` is its size.
    """
    left, top = origin
    return [
        layer.crop((left, top, left + tile.width, top + tile.height))
        for layer in result.as_images()
    ]


def add_margin(
    stack: List[Image.Image],
    left: int,
    top: int,
    right: int,
    bottom: int,
    left_margin: int,
    top_margin: int,
    right_margin: int,
    bottom_margin: int,
    tile: int,
    noise_source,
    fill_color,
) -> List[Image.Image]:
    """
    Build full-size tiles for a tile that extends past the edge of the image.

    For each image in the stack, the region given by the tile bounds plus the
    margins is cropped out, a `tile` x `tile` background is generated by
    `noise_source`, and the cropped region is pasted onto it at
    (`left_margin`, `top_margin`).
    """
    results = []
    for source in stack:
        base_image = source.crop(
            (
                left + left_margin,
                top + top_margin,
                right + right_margin,
                bottom + bottom_margin,
            )
        )
        tile_image = noise_source(base_image, (tile, tile), (0, 0), fill=fill_color)
        tile_image.paste(base_image, (left_margin, top_margin))
        results.append(tile_image)

    return results