diff --git a/api/onnx_web/chain/blend_grid.py b/api/onnx_web/chain/blend_grid.py index 234f6c3b..5ca17151 100644 --- a/api/onnx_web/chain/blend_grid.py +++ b/api/onnx_web/chain/blend_grid.py @@ -3,7 +3,7 @@ from typing import Optional from PIL import Image -from ..params import ImageParams, SizeChart, StageParams +from ..params import ImageParams, Size, SizeChart, StageParams from ..server import ServerContext from ..worker import ProgressCallback, WorkerContext from .base import BaseStage @@ -36,9 +36,10 @@ class BlendGridStage(BaseStage): logger.info("combining source images using grid layout") images = sources.as_image() - size = images[0].size + ref_image = images[0] + size = Size(*ref_image.size) - output = Image.new("RGB", (size[0] * width, size[1] * height)) + output = Image.new(ref_image.mode, (size.width * width, size.height * height)) # TODO: labels if order is None: @@ -49,7 +50,7 @@ class BlendGridStage(BaseStage): y = i // width n = order[i] - output.paste(images[n], (x * size[0], y * size[1])) + output.paste(images[n], (x * size.width, y * size.height)) return StageResult(images=[*images, output]) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 4ebb1498..4486bbf6 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -30,7 +30,8 @@ class BlendMaskStage(BaseStage): ) -> StageResult: logger.info("blending image using mask") - mult_mask = Image.new("RGBA", stage_mask.size, color="black") + # TODO: does this need an alpha channel? + mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black") mult_mask.alpha_composite(stage_mask) mult_mask = mult_mask.convert("L") diff --git a/api/onnx_web/chain/result.py b/api/onnx_web/chain/result.py index 9bd7395d..ea19c850 100644 --- a/api/onnx_web/chain/result.py +++ b/api/onnx_web/chain/result.py @@ -52,4 +52,16 @@ class StageResult: if self.images is not None: return self.images - return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays] + return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays] + + +def shape_mode(arr: np.ndarray) -> str: + if len(arr.shape) != 3: + raise ValueError("unknown array format") + + if arr.shape[-1] == 3: + return "RGB" + elif arr.shape[-1] == 4: + return "RGBA" + + raise ValueError("unknown image format") \ No newline at end of file diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index b88d9628..6b738556 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -151,7 +151,9 @@ def blend_tiles( "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap ) - scaled_size = (height * scale, width * scale, 3) + channels = max([4 if tile_image.mode == "RGBA" else 3 for _left, _top, tile_image in tiles]) + scaled_size = (height * scale, width * scale, channels) + count = np.zeros(scaled_size) value = np.zeros(scaled_size) @@ -221,7 +223,7 @@ def blend_tiles( margin_left : equalized.shape[1] + margin_right, np.newaxis, ], - 3, + channels, axis=2, ) diff --git a/api/onnx_web/image/mask_filter.py b/api/onnx_web/image/mask_filter.py index 82a19dfa..967fce1f 100644 --- a/api/onnx_web/image/mask_filter.py +++ b/api/onnx_web/image/mask_filter.py @@ -8,7 +8,7 @@ def mask_filter_none( ) -> Image.Image: width, height = dims - noise = Image.new("RGB", (width, height), fill) + noise = Image.new(mask.mode, (width, height), fill) noise.paste(mask, origin) return noise diff --git a/api/onnx_web/image/noise_source.py b/api/onnx_web/image/noise_source.py index a1dd47f1..2c260f14 100644 --- a/api/onnx_web/image/noise_source.py +++ b/api/onnx_web/image/noise_source.py @@ -17,21 +17,21 @@ def noise_source_fill_edge( """ width, height = dims - noise = Image.new("RGB", (width, height), fill) + noise = Image.new(source.mode, (width, height), fill) noise.paste(source, origin) return noise def noise_source_fill_mask( - _source: Image.Image, dims: Point, _origin: Point, fill="white", **kw + source: Image.Image, dims: Point, _origin: Point, fill="white", **kw ) -> Image.Image: """ Fill the whole canvas, no source or noise. """ width, height = dims - noise = Image.new("RGB", (width, height), fill) + noise = Image.new(source.mode, (width, height), fill) return noise @@ -52,7 +52,7 @@ def noise_source_gaussian( def noise_source_uniform( - _source: Image.Image, dims: Point, _origin: Point, **kw + source: Image.Image, dims: Point, _origin: Point, **kw ) -> Image.Image: width, height = dims size = width * height @@ -61,6 +61,7 @@ def noise_source_uniform( noise_g = random.uniform(0, 256, size=size) noise_b = random.uniform(0, 256, size=size) + # needs to be RGB for pixel manipulation noise = Image.new("RGB", (width, height)) for x in range(width): @@ -68,11 +69,11 @@ def noise_source_uniform( i = get_pixel_index(x, y, width) noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) - return noise + return noise.convert(source.mode) def noise_source_normal( - _source: Image.Image, dims: Point, _origin: Point, **kw + source: Image.Image, dims: Point, _origin: Point, **kw ) -> Image.Image: width, height = dims size = width * height @@ -81,6 +82,7 @@ def noise_source_normal( noise_g = random.normal(128, 32, size=size) noise_b = random.normal(128, 32, size=size) + # needs to be RGB for pixel manipulation noise = Image.new("RGB", (width, height)) for x in range(width): @@ -88,7 +90,7 @@ def noise_source_normal( i = get_pixel_index(x, y, width) noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) - return noise + return noise.convert(source.mode) def noise_source_histogram( @@ -112,6 +114,7 @@ def noise_source_histogram( 256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size ) + # needs to be RGB for pixel manipulation noise = Image.new("RGB", (width, height)) for x in range(width): @@ -119,4 +122,4 @@ def noise_source_histogram( i = get_pixel_index(x, y, width) noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i])) - return noise + return noise.convert(source.mode) diff --git a/api/onnx_web/image/utils.py b/api/onnx_web/image/utils.py index 4e2f3a7a..a1264ff0 100644 --- a/api/onnx_web/image/utils.py +++ b/api/onnx_web/image/utils.py @@ -20,7 +20,7 @@ def expand_image( size = tuple(size) origin = (expand.left, expand.top) - full_source = Image.new("RGB", size, fill) + full_source = Image.new(source.mode, size, fill) full_source.paste(source, origin) # new mask pixels need to be filled with white so they will be replaced