2023-01-28 23:09:19 +00:00
from logging import getLogger
2023-04-04 02:39:10 +00:00
from math import ceil
2023-07-02 23:14:52 +00:00
from typing import List , Optional , Protocol , Tuple
2023-07-07 02:46:36 +00:00
from enum import Enum
import itertools
2023-01-28 18:42:02 +00:00
2023-06-02 06:14:40 +00:00
import numpy as np
2023-02-05 13:53:26 +00:00
from PIL import Image
2023-07-02 23:14:52 +00:00
from . . params import Size , TileOrder
2023-07-07 02:46:36 +00:00
from . . image . noise_source import noise_source_histogram
2023-02-12 00:00:18 +00:00
2023-06-04 02:00:59 +00:00
# from skimage.exposure import match_histograms
2023-01-28 23:09:19 +00:00
logger = getLogger ( __name__ )
2023-01-28 18:42:02 +00:00
class TileCallback(Protocol):
    """
    Structural type for a per-tile processing stage.

    Any callable matching ``__call__`` below can be placed in the ``filters``
    list accepted by the tiling functions in this module.
    """

    def __call__(self, image: Image.Image, dims: Tuple[int, int, int]) -> Image.Image:
        """
        Process a single tile and return the resulting image.

        ``dims`` is the (left, top, tile) triple built by the tiling loops, and
        ``image`` may be None when the caller supplied no source image.
        """
        ...
2023-01-28 14:19:40 +00:00
2023-04-29 19:23:00 +00:00
def complete_tile(
    source: Image.Image,
    tile: int,
) -> Image.Image:
    """
    Grow an undersized tile so that it is tile x tile pixels.

    The source is returned untouched when it is None or already at least
    ``tile`` pixels in both dimensions. Otherwise it is pasted at the origin of
    a new tile x tile canvas (default fill of ``Image.new``) in the same mode.
    """
    if source is None or (source.width >= tile and source.height >= tile):
        return source

    padded = Image.new(source.mode, (tile, tile))
    padded.paste(source)
    return padded
2023-07-02 23:14:52 +00:00
def needs_tile(
    max_tile: int,
    stage_tile: int,
    size: Optional[Size] = None,
    source: Optional[Image.Image] = None,
) -> bool:
    """
    Decide whether an image is larger than the effective tile size.

    The effective tile is the smaller of ``max_tile`` and ``stage_tile``. The
    dimensions come from ``source`` when given, otherwise from ``size``. With
    neither, no tiling is needed.
    """
    limit = min(max_tile, stage_tile)

    if source is not None:
        dims = (source.width, source.height)
    elif size is not None:
        dims = (size.width, size.height)
    else:
        return False

    return any(dim > limit for dim in dims)
2023-06-03 14:51:44 +00:00
def get_tile_grads(
    left: int,
    top: int,
    tile: int,
    width: int,
    height: int,
) -> Tuple[List[int], List[int]]:
    """
    Build the four blend-weight control values for a tile along each axis.

    The default weights ``[0, 1, 1, 0]`` fade the tile in and out at both
    edges. An edge that touches or overruns the image border has no neighbour
    to blend with, so its weight is held at 1 instead.

    Args:
        left, top: tile origin in unscaled image pixels.
        tile: the tile step used to decide whether a neighbour follows
            (blend_tiles passes the overlap-adjusted size here).
        width, height: unscaled image size.

    Returns:
        ``(grad_x, grad_y)``, each a list of four int weights.
    """
    grad_x = [
        1 if left <= 0 else 0,
        1,
        1,
        1 if (left + tile) >= width else 0,
    ]
    grad_y = [
        1 if top <= 0 else 0,
        1,
        1,
        1 if (top + tile) >= height else 0,
    ]
    return (grad_x, grad_y)
2023-06-04 01:35:21 +00:00
def blend_tiles(
    tiles: List[Tuple[int, int, Image.Image]],
    scale: int,
    width: int,
    height: int,
    tile: int,
    overlap: float,
) -> Image.Image:
    """
    Blend processed tiles into one output image.

    When tiles overlap, each one is weighted by a mask that ramps linearly
    across its overlapping borders (control values from get_tile_grads). The
    weighted pixels and the weights are accumulated on a
    (height * scale, width * scale) canvas, and each pixel is normalised by
    the total weight it received.

    Args:
        tiles: (left, top, image) triples. left/top are unscaled image pixels.
            Assumes 3-channel (RGB) tile images that are tile * scale pixels
            per side when overlap > 0 (TODO confirm for filters that resize).
        scale: factor between unscaled coordinates and output pixels.
        width, height: unscaled size of the full image.
        tile: unscaled tile size.
        overlap: fraction of the tile shared with its neighbour.

    Returns:
        The blended image, ``width * scale`` by ``height * scale`` pixels.
    """
    adj_tile = int(float(tile) * (1.0 - overlap))
    logger.debug(
        "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
    )

    scaled_size = (height * scale, width * scale, 3)
    count = np.zeros(scaled_size)
    value = np.zeros(scaled_size)

    for left, top, tile_image in tiles:
        # histogram matching is currently disabled (see the commented-out
        # skimage import), so this is just the raw tile as floats
        tile_pixels = np.array(tile_image).astype(np.float32)
        mask = np.ones_like(tile_pixels[:, :, 0])

        if adj_tile < tile:
            # sort gradient points so np.interp sees non-decreasing x values
            p1 = adj_tile * scale
            p2 = (tile - adj_tile) * scale
            points = [0, min(p1, p2), max(p1, p2), tile * scale]

            # gradient blending
            grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height)
            logger.debug("tile gradients: %s, %s, %s", points, grad_x, grad_y)

            # one vectorised interp per axis instead of one call per pixel
            positions = np.arange(tile * scale)
            mult_x = np.interp(positions, points, grad_x)
            mult_y = np.interp(positions, points, grad_y)

            # outer product: x weights across columns, y weights down rows
            mask = ((mask * mult_x).T * mult_y).T
            tile_pixels[:, :, :3] *= mask[:, :, np.newaxis]

        scaled_top = top * scale
        scaled_left = left * scale
        scaled_bottom = scaled_top + tile_pixels.shape[0]
        scaled_right = scaled_left + tile_pixels.shape[1]

        # the tile may overhang the canvas (spiral edge tiles, padded grid
        # tiles); only the part that falls inside it is written
        writable_top = max(scaled_top, 0)
        writable_left = max(scaled_left, 0)
        writable_bottom = min(scaled_bottom, scaled_size[0])
        writable_right = min(scaled_right, scaled_size[1])
        if writable_bottom <= writable_top or writable_right <= writable_left:
            # nothing to write; slicing with a negative stop would wrap around
            logger.debug("tile at %s, %s is outside the output, skipping", left, top)
            continue

        # the matching window inside the tile
        src_top = writable_top - scaled_top
        src_left = writable_left - scaled_left
        src_bottom = src_top + (writable_bottom - writable_top)
        src_right = src_left + (writable_right - writable_left)
        logger.debug(
            "tile broadcast shapes: %s, %s, %s, %s",
            writable_top,
            writable_left,
            writable_bottom,
            writable_right,
        )
        logger.debug(
            "writing shapes: %s, %s, %s, %s",
            src_top,
            src_bottom,
            src_left,
            src_right,
        )

        # accumulation
        value[writable_top:writable_bottom, writable_left:writable_right, :] += tile_pixels[
            src_top:src_bottom, src_left:src_right, :
        ]
        count[writable_top:writable_bottom, writable_left:writable_right, :] += mask[
            src_top:src_bottom, src_left:src_right, np.newaxis
        ]

    # NOTE(review): stdlib logging.Logger has no `trace`; presumably the
    # project installs one elsewhere — confirm.
    logger.trace("mean tiles contributing to each pixel: %s", np.mean(count))

    # normalise only where some weight landed; elsewhere keep the raw value
    # (same result as np.where, without the 0/0 RuntimeWarning)
    pixels = np.divide(value, count, out=value.copy(), where=count > 0)
    return Image.fromarray(np.uint8(np.clip(pixels, 0, 255)))
2023-01-29 05:46:36 +00:00
2023-06-04 01:35:21 +00:00
def process_tile_grid(
    source: Image.Image,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    overlap: float = 0.0,
    **kwargs,
) -> Image.Image:
    """
    Run ``filters`` over a regular grid of tiles and blend the results.

    Tiles are ``tile`` pixels wide and are placed every
    ``int(tile * (1 - overlap))`` pixels in row-major order. Each tile is
    cropped from ``source`` (None when there is no source), grown to full size
    by complete_tile, passed through every filter with ``(left, top, tile)``,
    and finally blended by blend_tiles.

    Args:
        size (kwarg): optional (width, height) overriding ``source.size``.

    Raises:
        ValueError: if ``overlap`` leaves a non-positive step between tiles.
    """
    width, height = kwargs.get("size", source.size if source is not None else None)

    adj_tile = int(float(tile) * (1.0 - overlap))
    if adj_tile <= 0:
        # previously a ZeroDivisionError (step 0) or an empty grid (step < 0)
        raise ValueError(
            f"overlap {overlap} leaves no step between tiles of size {tile}"
        )

    tiles_x = ceil(width / adj_tile)
    tiles_y = ceil(height / adj_tile)
    total = tiles_x * tiles_y
    logger.debug(
        "processing %s tiles (%s x %s) with adjusted size of %s, %s overlap",
        total,
        tiles_x,
        tiles_y,
        adj_tile,
        overlap,
    )

    tiles: List[Tuple[int, int, Image.Image]] = []
    for y in range(tiles_y):
        for x in range(tiles_x):
            idx = (y * tiles_x) + x
            left = x * adj_tile
            top = y * adj_tile
            logger.info("processing tile %s of %s, %s.%s", idx + 1, total, y, x)

            tile_image = (
                source.crop((left, top, left + tile, top + tile))
                if source is not None
                else None
            )
            tile_image = complete_tile(tile_image, tile)
            for image_filter in filters:
                tile_image = image_filter(tile_image, (left, top, tile))

            tiles.append((left, top, tile_image))

    return blend_tiles(tiles, scale, width, height, tile, overlap)
2023-06-04 01:35:21 +00:00
2023-01-29 05:46:36 +00:00
def process_tile_spiral(
    source: Image.Image,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    overlap: float = 0.5,
    **kwargs,
) -> Image.Image:
    """
    Run ``filters`` over tiles visited in spiral order and blend the results.

    Tile origins come from generate_tile_spiral and may overhang the image.
    For an overhanging tile, the in-bounds part of ``source`` is cropped and a
    full tile x tile image is produced by noise_source_histogram. The crop is
    then pasted back at its original offset inside that tile. Every tile is
    passed through each filter with ``(left, top, tile)`` and then blended by
    blend_tiles.

    Args:
        size (kwarg): optional (width, height) overriding ``source.size``.
    """
    width, height = kwargs.get("size", source.size if source is not None else None)

    # tile tuples are in source pixels; blend_tiles multiplies by scale
    tiles: List[Tuple[int, int, Image.Image]] = []
    tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap)
    for counter, (left, top) in enumerate(tile_coords, start=1):
        logger.info(
            "processing tile %s of %s, %s x %s", counter, len(tile_coords), left, top
        )
        right = left + tile
        bottom = top + tile

        if source is None:
            # nothing to crop; let the filters generate the tile
            tile_image = None
        elif left < 0 or top < 0 or right > width or bottom > height:
            # clamp the crop box to the image so only real pixels are pasted
            # over the noise fill (the old margin maths grew the box past the
            # right/bottom edges, pasting black over the noise there)
            crop_left = max(left, 0)
            crop_top = max(top, 0)
            crop_right = min(right, width)
            crop_bottom = min(bottom, height)
            base_image = source.crop((crop_left, crop_top, crop_right, crop_bottom))
            tile_image = noise_source_histogram(base_image, (tile, tile), (0, 0))
            tile_image.paste(base_image, (crop_left - left, crop_top - top))
        else:
            tile_image = source.crop((left, top, right, bottom))

        for image_filter in filters:
            tile_image = image_filter(tile_image, (left, top, tile))

        tiles.append((left, top, tile_image))

    return blend_tiles(tiles, scale, width, height, tile, overlap)
2023-02-12 00:00:18 +00:00
def process_tile_order(
    order: TileOrder,
    source: Image.Image,
    tile: int,
    scale: int,
    filters: List[TileCallback],
    **kwargs,
) -> Image.Image:
    """
    Dispatch to the tiling strategy selected by ``order``.

    Extra keyword arguments (e.g. ``overlap``, ``size``) are forwarded to the
    chosen strategy.

    Raises:
        NotImplementedError: for ``TileOrder.kernel``.
        ValueError: for any other unrecognised order.
    """
    if order == TileOrder.grid:
        logger.debug("using grid tile order with tile size: %s", tile)
        return process_tile_grid(source, tile, scale, filters, **kwargs)

    if order == TileOrder.kernel:
        logger.debug("using kernel tile order with tile size: %s", tile)
        raise NotImplementedError("kernel tile order is not implemented")

    if order == TileOrder.spiral:
        logger.debug("using spiral tile order with tile size: %s", tile)
        return process_tile_spiral(source, tile, scale, filters, **kwargs)

    # Logger.warn is a deprecated alias of Logger.warning
    logger.warning("unknown tile order: %s", order)
    raise ValueError(f"unknown tile order: {order}")
2023-04-04 02:39:10 +00:00
def generate_tile_spiral(
    width: int,
    height: int,
    tile: int,
    overlap: float = 0.0,
) -> List[Tuple[int, int]]:
    """
    Generate tile origins covering a width x height area in clockwise spiral order.

    The tiles form a grid centred on the image, stepped by
    ``tile * (1 - overlap)`` rounded to an even number of pixels. Edge tiles
    may therefore start at negative coordinates or run past the right/bottom
    border. The walk starts at the north-west tile heading east and spirals
    inwards (east, south, west, north, ...).

    Returns:
        (left, top) integer pixel origins, one per tile, in visiting order.

    Raises:
        ValueError: if the step is not positive while more than one tile is
            needed along some axis.
    """
    spacing = 1.0 - overlap
    # dividing and then multiplying by 2 makes the step even; with an even
    # width and tile the centring offset below is then exact
    tile_increment = round(tile * spacing / 2) * 2
    if tile_increment <= 0 and (width > tile or height > tile):
        raise ValueError(
            f"overlap {overlap} gives a non-positive tile step for tile size {tile}"
        )

    # number of tiles needed along each axis
    width_tile_target = 1
    if width > tile:
        width_tile_target = 1 + ceil((width - tile) / tile_increment)
    height_tile_target = 1
    if height > tile:
        height_tile_target = 1 + ceil((height - tile) / tile_increment)

    # centre the tiled span on the image. The span is never smaller than the
    # image, so the offsets are <= 0; floor division keeps every tile exactly
    # tile_increment apart even when the excess is odd.
    span_x = tile + (width_tile_target - 1) * tile_increment
    span_y = tile + (height_tile_target - 1) * tile_increment
    tile_left = (width - span_x) // 2
    tile_top = (height - span_y) // 2
    logger.debug(
        "image size %s x %s, tiling to %s x %s, starting at %s, %s",
        width,
        height,
        width_tile_target,
        height_tile_target,
        tile_left,
        tile_top,
    )

    tile_coords: List[Tuple[int, int]] = []
    # east, south, west, north
    directions = ((1, 0), (0, 1), (-1, 0), (0, -1))
    # unvisited columns/rows still to walk; the first eastward leg walks the
    # whole top row, so that row is not counted in remaining_rows
    remaining_cols = width_tile_target
    remaining_rows = height_tile_target - 1
    # grid cell index, starting one step west of the first tile
    col, row = -1, 0
    for dx, dy in itertools.cycle(directions):
        steps = remaining_cols if dx else remaining_rows
        if steps == 0:
            break

        for _ in range(steps):
            col += dx
            row += dy
            left = tile_left + col * tile_increment
            top = tile_top + row * tile_increment
            logger.debug("adding tile at %s:%s", left, top)
            tile_coords.append((int(left), int(top)))

        remaining_cols -= abs(dx)
        remaining_rows -= abs(dy)

    return tile_coords