diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 6b738556..a6705e5c 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -2,7 +2,7 @@ import itertools from enum import Enum from logging import getLogger from math import ceil -from typing import Any, Callable, List, Optional, Protocol, Tuple +from typing import Any, Callable, List, Optional, Protocol, Tuple, Union import numpy as np from PIL import Image @@ -138,6 +138,20 @@ def make_tile_mask( return mask +def get_channels(image: Union[np.ndarray, Image.Image]) -> int: + if isinstance(image, np.ndarray): + return image.shape[-1] + + if image.mode == "RGBA": + return 4 + elif image.mode == "RGB": + return 3 + elif image.mode == "L": + return 1 + + raise ValueError("unknown image format") + + def blend_tiles( tiles: List[Tuple[int, int, Image.Image]], scale: int, @@ -151,7 +165,7 @@ def blend_tiles( "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap ) - channels = max([4 if tile_image.mode == "RGBA" else 3 for _left, _top, tile_image in tiles]) + channels = max([get_channels(tile_image) for _left, _top, tile_image in tiles]) scaled_size = (height * scale, width * scale, channels) count = np.zeros(scaled_size)