diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index d4ae7694..00e05b12 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -114,7 +114,7 @@ def blend_inpaint( return result.images[0] - output = process_tile_order(stage.tile_order, source, SizeChart.auto, 1, [outpaint]) + output = process_tile_order(stage.tile_order, source, SizeChart.auto, 1, [outpaint], overlap=params.overlap) logger.info("final output image size: %s", output.size) return output diff --git a/api/onnx_web/chain/utils.py b/api/onnx_web/chain/utils.py index ce5443e3..855b3fa6 100644 --- a/api/onnx_web/chain/utils.py +++ b/api/onnx_web/chain/utils.py @@ -60,6 +60,55 @@ def get_tile_grads( return (grad_x, grad_y) +def blend_tiles( + tiles: List[Tuple[int, int, Image.Image]], + scale: int, + width: int, + height: int, + tile: int, + overlap: float, +): + adj_tile = int(float(tile) * overlap) + scaled_size = (height * scale, width * scale, 3) + count = np.zeros(scaled_size) + value = np.zeros(scaled_size) + ref = np.array(tiles[0][2]) + + for left, top, tile_image in tiles: + # histogram equalization + equalized = np.array(tile_image) + equalized = match_histograms(equalized, ref, channel_axis=-1) + + # gradient blending + points = [0, adj_tile * scale, (tile - adj_tile) * scale, (tile * scale) - 1] + grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height) + mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)] + mult_y = [np.interp(i, points, grad_y) for i in range(tile * scale)] + + mask = np.ones_like(equalized[:, :, 0]) * mult_x + mask = (mask.T * mult_y).T + for c in range(3): + equalized[:, :, c] = (equalized[:, :, c] * mask).astype(np.uint8) + + scaled_top = top * scale + scaled_left = left * scale + + # equalized size may be wrong/too much + scaled_bottom = min(scaled_top + equalized.shape[0], scaled_size[0]) + scaled_right = min(scaled_left + equalized.shape[1], scaled_size[1]) + + # accumulation + value[ + scaled_top : scaled_bottom, scaled_left : scaled_right, : + ] += equalized[0 : scaled_bottom - scaled_top, 0 : scaled_right - scaled_left, :] + count[ + scaled_top : scaled_bottom, scaled_left : scaled_right, : + ] += np.repeat(mask[0 : scaled_bottom - scaled_top, 0 : scaled_right - scaled_left, np.newaxis], 3, axis=2) + + pixels = np.where(count > 0, value / count, value) + return Image.fromarray(pixels) + + def process_tile_grid( source: Image.Image, tile: int, @@ -92,44 +141,7 @@ def process_tile_grid( tiles.append((left, top, tile_image)) - scaled_size = (height * scale, width * scale, 3) - count = np.zeros(scaled_size) - value = np.zeros(scaled_size) - ref = np.array(tiles[0][2]) - - for left, top, tile_image in tiles: - # histogram equalization - equalized = np.array(tile_image) - equalized = match_histograms(equalized, ref, channel_axis=-1) - - # gradient blending - points = [0, adj_tile * scale, (tile - adj_tile) * scale, (tile * scale) - 1] - grad_x, grad_y = get_tile_grads(left, top, adj_tile, width, height) - mult_x = [np.interp(i, points, grad_x) for i in range(tile * scale)] - mult_y = [np.interp(i, points, grad_y) for i in range(tile * scale)] - - mask = np.ones_like(equalized[:, :, 0]) * mult_x - mask = (mask.T * mult_y).T - for c in range(3): - equalized[:, :, c] = (equalized[:, :, c] * mask).astype(np.uint8) - - # accumulation - # equalized size may be wrong/too much - scaled_top = top * scale - scaled_left = left * scale - - scaled_bottom = min(scaled_top + equalized.shape[0], scaled_size[0]) - scaled_right = min(scaled_left + equalized.shape[1], scaled_size[1]) - - value[ - scaled_top : scaled_bottom, scaled_left : scaled_right, : - ] += equalized[0 : scaled_bottom - scaled_top, 0 : scaled_right - scaled_left, :] - count[ - scaled_top : scaled_bottom, scaled_left : scaled_right, : - ] += np.repeat(mask[0 : scaled_bottom - scaled_top, 0 : scaled_right - scaled_left, np.newaxis], 3, axis=2) - - pixels = np.where(count > 0, value / count, value) - return Image.fromarray(pixels) + return blend_tiles(tiles, scale, width, height, tile, adj_tile) def process_tile_spiral( @@ -144,15 +156,20 @@ def process_tile_spiral( raise ValueError("unsupported scale") width, height = source.size + + # spiral uses the previous run and needs a scratch texture for 3x memory image = Image.new("RGB", (width * scale, height * scale)) image.paste(source, (0, 0, width, height)) + tiles: List[Tuple[int, int, Image.Image]] = [] + tiles.append((0, 0, source)) + # tile tuples is source, multiply by scale for dest counter = 0 - tiles = generate_tile_spiral(width, height, tile, overlap=overlap) - for left, top in tiles: + tile_coords = generate_tile_spiral(width, height, tile, overlap=overlap) + for left, top in tile_coords: counter += 1 - logger.debug("processing tile %s of %s, %sx%s", counter, len(tiles), left, top) + logger.debug("processing tile %s of %s, %sx%s", counter, len(tile_coords), left, top) tile_image = image.crop((left, top, left + tile, top + tile)) tile_image = complete_tile(tile_image, tile) @@ -160,9 +177,9 @@ def process_tile_spiral( for filter in filters: tile_image = filter(tile_image, (left, top, tile)) - image.paste(tile_image, (left * scale, top * scale)) + tiles.append((left, top, tile_image)) - return image + return blend_tiles(tiles, scale, width, height, tile, overlap) def process_tile_order( diff --git a/api/onnx_web/diffusers/run.py b/api/onnx_web/diffusers/run.py index 9cc42a53..f52b1d64 100644 --- a/api/onnx_web/diffusers/run.py +++ b/api/onnx_web/diffusers/run.py @@ -217,7 +217,7 @@ def run_highres( size.height // highres.scale, highres.scale, [highres_tile], - overlap=0, + overlap=params.overlap, ) return image