diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index 667a541d..fb855a72 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -2,7 +2,9 @@ from logging import getLogger from math import ceil from typing import List, Protocol, Tuple +import numpy as np from PIL import Image +from skimage.exposure import match_histograms from ..params import TileOrder @@ -38,20 +40,23 @@ def process_tile_grid( tile: int, scale: int, filters: List[TileCallback], + overlap: float = 0.5, **kwargs, ) -> Image.Image: width, height = source.size - image = Image.new(source.mode, (width * scale, height * scale)) - tiles_x = ceil(width / tile) - tiles_y = ceil(height / tile) + adj_tile = tile * overlap + tiles_x = ceil(width / adj_tile) + tiles_y = ceil(height / adj_tile) total = tiles_x * tiles_y + tiles: List[Tuple[int, int, Image.Image]] = [] + for y in range(tiles_y): for x in range(tiles_x): idx = (y * tiles_x) + x - left = x * tile - top = y * tile + left = x * adj_tile + top = y * adj_tile logger.debug("processing tile %s of %s, %s.%s", idx + 1, total, y, x) tile_image = source.crop((left, top, left + tile, top + tile)) @@ -60,9 +65,24 @@ def process_tile_grid( for filter in filters: tile_image = filter(tile_image, (left, top, tile)) - image.paste(tile_image, (left * scale, top * scale)) + tiles.append((left, top, tile_image)) - return image + scaled_size = (width * scale, height * scale, 3) + count = np.zeros_like(scaled_size) + value = np.zeros_like(scaled_size) + ref = tiles[0][2] + + for left, top, tile_image in tiles: + equalized = match_histograms(tile_image, ref, channel_axis=-1) + value[ + top * scale : (top + tile) * scale, left * scale : (left + tile) * scale, : + ] += np.array(equalized) + count[ + top * scale : (top + tile) * scale, left * scale : (left + tile) * scale, : + ] += 1 + + pixels = np.where(count > 0, value / count, value) + return Image.fromarray(pixels) def process_tile_spiral(