1
0
Fork 0
onnx-web/api/onnx_web/chain/utils.py

263 lines
7.2 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from logging import getLogger
2023-04-04 02:39:10 +00:00
from math import ceil
from typing import List, Protocol, Tuple
import numpy as np
2023-02-05 13:53:26 +00:00
from PIL import Image
from skimage.exposure import match_histograms
2023-02-05 13:53:26 +00:00
from ..params import TileOrder
2023-01-28 23:09:19 +00:00
# Module-level logger, named after this module.
logger = getLogger(__name__)
class TileCallback(Protocol):
    """
    Structural type for a single per-tile processing stage.

    Any callable that accepts a tile image plus its ``(left, top, tile)``
    position and returns the processed tile satisfies this protocol.
    """

    def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
        """
        Process one tile.

        Args:
            image: the tile to process.
            dims: ``(left, top, tile)`` as passed by the tiling helpers in
                this module: the tile's offset in the source and its edge
                length.

        Returns:
            The processed tile.
        """
        ...
def complete_tile(
    source: Image.Image,
    tile: int,
) -> Image.Image:
    """
    Pad ``source`` onto a ``tile`` x ``tile`` canvas when either side is short.

    The source is pasted at the top-left corner of a new image with the same
    mode; the rest of the canvas keeps ``Image.new``'s default fill. If both
    dimensions are already at least ``tile``, the same object is returned
    unchanged.
    """
    if source.width >= tile and source.height >= tile:
        return source

    padded = Image.new(source.mode, (tile, tile))
    padded.paste(source)
    return padded
2023-06-03 14:51:44 +00:00
def get_tile_grads(
    left: int,
    top: int,
    tile: int,
    width: int,
    height: int,
) -> Tuple[Tuple[float, float, float, float], Tuple[float, float, float, float]]:
    """
    Build the four gradient control values per axis used to blend one tile.

    Interior tiles fade in and out along each axis: ``(0, 1, 1, 0)``. A side
    that touches the image border is held at full weight instead, so the
    first value is 1 when the tile starts at (or before) the origin, and the
    last value is 1 when ``offset + tile`` reaches the image extent.

    Args:
        left: horizontal offset of the tile.
        top: vertical offset of the tile.
        tile: extent used for the right/bottom border check.
        width: image width.
        height: image height.

    Returns:
        ``(grad_x, grad_y)``, each an immutable 4-tuple of floats, matching
        the annotated return type.
    """
    grad_x = (
        1.0 if left <= 0 else 0.0,
        1.0,
        1.0,
        1.0 if (left + tile) >= width else 0.0,
    )
    grad_y = (
        1.0 if top <= 0 else 0.0,
        1.0,
        1.0,
        1.0 if (top + tile) >= height else 0.0,
    )
    return grad_x, grad_y
def process_tile_grid(
    source: Image.Image,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    overlap: float = 0.5,
    **kwargs,
) -> Image.Image:
    """
    Run ``filters`` over a grid of overlapping tiles and blend the results.

    Square tiles of ``tile`` pixels are cropped every ``int(tile * overlap)``
    pixels along each axis. Every filtered tile is histogram-matched to the
    first tile and weighted by a separable gradient mask built from
    ``get_tile_grads``. The weighted tiles are accumulated into a canvas
    ``scale`` times the source size, which is then normalized by the
    accumulated weights. Tiles on the right/bottom edges may extend past the
    image; only the part that lands inside the output is accumulated.

    Assumes each filter chain returns an RGB tile of ``tile * scale`` pixels
    square (TODO confirm for all filters).

    Args:
        source: image to tile.
        tile: tile edge length in source pixels.
        scale: factor by which the filters enlarge each tile; the output is
            ``(width * scale, height * scale)``.
        filters: stages applied in order to each tile, called as
            ``stage(tile_image, (left, top, tile))``.
        overlap: stride between tile origins as a fraction of ``tile``
            (0.5 starts a new tile every half tile).
        kwargs: accepted for signature compatibility with other tile orders;
            unused.

    Returns:
        The blended RGB image.
    """
    width, height = source.size
    adj_tile = int(float(tile) * overlap)

    tiles_x = ceil(width / adj_tile)
    tiles_y = ceil(height / adj_tile)
    total = tiles_x * tiles_y

    tiles: List[Tuple[int, int, Image.Image]] = []
    for y in range(tiles_y):
        for x in range(tiles_x):
            idx = (y * tiles_x) + x
            left = x * adj_tile
            top = y * adj_tile
            logger.debug("processing tile %s of %s, %s.%s", idx + 1, total, y, x)

            tile_image = source.crop((left, top, left + tile, top + tile))
            tile_image = complete_tile(tile_image, tile)
            for stage in filters:
                tile_image = stage(tile_image, (left, top, tile))

            tiles.append((left, top, tile_image))

    out_width = width * scale
    out_height = height * scale

    if not tiles:
        # zero-sized source: nothing to blend and no reference tile
        return Image.new("RGB", (out_width, out_height))

    # numpy images are indexed (row, column, channel) == (height, width, channel)
    value = np.zeros((out_height, out_width, 3))
    weight = np.zeros((out_height, out_width))

    ref = np.array(tiles[0][2])

    # control points for the blending ramps, in scaled tile pixels
    # NOTE(review): np.interp requires increasing xp; with overlap > 0.5 the two
    # middle points are out of order and the ramp is not meaningful.
    points = [0, adj_tile * scale, (tile - adj_tile) * scale, (tile * scale) - 1]
    positions = np.arange(tile * scale)

    for left, top, tile_image in tiles:
        # histogram equalization against the first tile
        equalized = np.asarray(
            match_histograms(np.array(tile_image), ref, channel_axis=-1),
            dtype=np.float64,
        )

        # gradient blending: separable mask, mask[row, col] = mult_y[row] * mult_x[col]
        grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
        mult_x = np.interp(positions, points, grad_x)
        mult_y = np.interp(positions, points, grad_y)
        mask = np.outer(mult_y, mult_x)

        # keep full precision: truncating the weighted tile to uint8 before
        # normalizing would bias the blend toward black
        weighted = equalized * mask[:, :, np.newaxis]

        # accumulation; edge tiles can extend past the output, so clip them
        out_top = top * scale
        out_left = left * scale
        rows = min(weighted.shape[0], out_height - out_top)
        cols = min(weighted.shape[1], out_width - out_left)
        value[out_top : out_top + rows, out_left : out_left + cols] += weighted[:rows, :cols]
        weight[out_top : out_top + rows, out_left : out_left + cols] += mask[:rows, :cols]

    # normalize by the accumulated weight; pixels no tile weighted stay black
    total_weight = weight[:, :, np.newaxis]
    pixels = np.divide(value, total_weight, out=np.zeros_like(value), where=total_weight > 0)

    # Image.fromarray needs uint8 for an RGB result
    return Image.fromarray(np.clip(np.rint(pixels), 0, 255).astype(np.uint8))
def process_tile_spiral(
    source: Image.Image,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    overlap: float = 0.5,
    **kwargs,
) -> Image.Image:
    """
    Run ``filters`` over tiles visited in ``generate_tile_spiral`` order.

    The source is copied onto an RGB canvas. Each tile is cropped from that
    canvas, padded to ``tile`` pixels square, passed through every filter,
    and pasted back in place. Later tiles therefore see the output of earlier
    ones wherever they overlap.

    Args:
        source: image to process.
        tile: tile edge length in pixels.
        scale: must be 1; other values are rejected.
        filters: stages applied in order to each tile.
        overlap: forwarded to ``generate_tile_spiral``.
        kwargs: unused.

    Raises:
        ValueError: if ``scale`` is not 1.
    """
    if scale != 1:
        raise ValueError("unsupported scale")

    width, height = source.size
    canvas = Image.new("RGB", (width * scale, height * scale))
    canvas.paste(source, (0, 0, width, height))

    # coordinates are in source space; multiply by scale for the destination
    coords = generate_tile_spiral(width, height, tile, overlap=overlap)
    for number, (left, top) in enumerate(coords, start=1):
        logger.debug("processing tile %s of %s, %sx%s", number, len(coords), left, top)

        box = (left, top, left + tile, top + tile)
        piece = complete_tile(canvas.crop(box), tile)
        for stage in filters:
            piece = stage(piece, (left, top, tile))

        canvas.paste(piece, (left * scale, top * scale))

    return canvas
def process_tile_order(
    order: TileOrder,
    source: Image.Image,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    **kwargs,
) -> Image.Image:
    """
    Dispatch tiled processing to the implementation for ``order``.

    Args:
        order: which tile traversal to use.
        source: image to process.
        tile: tile edge length in pixels.
        scale: output scale factor, forwarded to the implementation.
        filters: stages applied to each tile.
        kwargs: forwarded to the grid/spiral implementations (e.g. ``overlap``).

    Returns:
        The processed image.

    Raises:
        NotImplementedError: for ``TileOrder.kernel``.
        ValueError: for an unrecognized ``order``.
    """
    if order == TileOrder.grid:
        logger.debug("using grid tile order with tile size: %s", tile)
        return process_tile_grid(source, tile, scale, filters, **kwargs)
    elif order == TileOrder.kernel:
        logger.debug("using kernel tile order with tile size: %s", tile)
        raise NotImplementedError("kernel tile order is not implemented")
    elif order == TileOrder.spiral:
        logger.debug("using spiral tile order with tile size: %s", tile)
        return process_tile_spiral(source, tile, scale, filters, **kwargs)

    # Logger.warn is a deprecated alias of Logger.warning
    logger.warning("unknown tile order: %s", order)
    raise ValueError(f"unknown tile order: {order}")
2023-04-04 02:39:10 +00:00
def generate_tile_spiral(
    width: int,
    height: int,
    tile: int,
    overlap: float = 0.0,
) -> List[Tuple[int, int]]:
    """
    Generate tile origins by walking a clockwise spiral in from the top-left.

    The walk starts heading east (right) and turns right after each leg
    (east -> south -> west -> north -> east). After every turn it shrinks the
    leg length along the axis it just turned onto. Within a leg, positions
    advance by ``(1 - overlap) * tile`` pixels per step.

    Args:
        width: image width in pixels.
        height: image height in pixels.
        tile: tile edge length in pixels.
        overlap: fraction of ``tile`` shared by consecutive tiles on a leg.

    Returns:
        ``(left, top)`` offsets, truncated to ``int``, in visiting order.
        Offsets are not clamped, so tiles may extend past the image edges.
    """
    # fraction of a tile advanced per step along a leg
    spacing = 1.0 - overlap

    # round dims up to nearest tiles
    tile_width = ceil(width / tile)
    tile_height = ceil(height / tile)

    # start walking from the north-west corner, heading east
    dir_height = 0
    dir_width = 1

    # remaining leg length along each axis, in whole tiles
    walk_height = tile_height
    walk_width = tile_width

    # progress along the current leg, in tiles (grows by spacing per step)
    accum_height = 0
    accum_width = 0

    # current position in pixels (becomes float once spacing is applied)
    tile_top = 0
    tile_left = 0

    tile_coords = []

    while walk_width > 0 and walk_height > 0:
        # exhaust the current direction, then turn
        # NOTE(review): legs are measured in whole tiles but advance by `spacing`
        # per step, so with overlap > 0 each leg emits more positions while the
        # shrink below still removes one whole tile per turn; confirm intent.
        while accum_width < walk_width and accum_height < walk_height:
            # add a tile
            # NOTE(review): stdlib logging.Logger has no trace() method; this
            # relies on the project installing one (custom TRACE level?) —
            # confirm, otherwise this raises AttributeError.
            logger.trace(
                "adding tile at %s:%s, %s:%s, %s:%s, %s",
                tile_left,
                tile_top,
                accum_width,
                accum_height,
                walk_width,
                walk_height,
                spacing,
            )
            tile_coords.append((int(tile_left), int(tile_top)))

            # move to the next
            tile_top += dir_height * spacing * tile
            tile_left += dir_width * spacing * tile

            accum_height += abs(dir_height * spacing)
            accum_width += abs(dir_width * spacing)

        # reset for the next direction
        accum_height = 0
        accum_width = 0

        # back off the overshoot before turning (original note: "why tho")
        # NOTE(review): the position is in pixels and overshoots by
        # spacing * tile, but this correction (and the post-turn step below)
        # moves by a single pixel, so each new leg starts one pixel diagonal
        # from the overshoot point rather than one tile step from the last
        # tile. Looks like a leftover from a tile-unit version — confirm.
        tile_top -= dir_height
        tile_left -= dir_width

        # turn right
        if dir_width == 1 and dir_height == 0:
            dir_width = 0
            dir_height = 1
        elif dir_width == 0 and dir_height == 1:
            dir_width = -1
            dir_height = 0
        elif dir_width == -1 and dir_height == 0:
            dir_width = 0
            dir_height = -1
        elif dir_width == 0 and dir_height == -1:
            dir_width = 1
            dir_height = 0

        # step to the next tile as part of the turn
        tile_top += dir_height
        tile_left += dir_width

        # shrink the leg length along the axis just turned onto
        walk_height -= abs(dir_height)
        walk_width -= abs(dir_width)

    return tile_coords