1
0
Fork 0

fix(api): better handling of alpha channels

This commit is contained in:
Sean Sube 2023-11-25 18:52:47 -06:00
parent c134edf4b3
commit 1c3b2f8dfc
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 37 additions and 18 deletions

View File

@ -3,7 +3,7 @@ from typing import Optional
from PIL import Image from PIL import Image
from ..params import ImageParams, SizeChart, StageParams from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage from .base import BaseStage
@ -36,9 +36,10 @@ class BlendGridStage(BaseStage):
logger.info("combining source images using grid layout") logger.info("combining source images using grid layout")
images = sources.as_image() images = sources.as_image()
size = images[0].size ref_image = images[0]
size = Size(*ref_image.size)
output = Image.new("RGB", (size[0] * width, size[1] * height)) output = Image.new(ref_image.mode, (size.width * width, size.height * height))
# TODO: labels # TODO: labels
if order is None: if order is None:
@ -49,7 +50,7 @@ class BlendGridStage(BaseStage):
y = i // width y = i // width
n = order[i] n = order[i]
output.paste(images[n], (x * size[0], y * size[1])) output.paste(images[n], (x * size.width, y * size.height))
return StageResult(images=[*images, output]) return StageResult(images=[*images, output])

View File

@ -30,7 +30,8 @@ class BlendMaskStage(BaseStage):
) -> StageResult: ) -> StageResult:
logger.info("blending image using mask") logger.info("blending image using mask")
mult_mask = Image.new("RGBA", stage_mask.size, color="black") # TODO: does this need an alpha channel?
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
mult_mask.alpha_composite(stage_mask) mult_mask.alpha_composite(stage_mask)
mult_mask = mult_mask.convert("L") mult_mask = mult_mask.convert("L")

View File

@ -52,4 +52,16 @@ class StageResult:
if self.images is not None: if self.images is not None:
return self.images return self.images
return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays] return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
def shape_mode(arr: np.ndarray) -> str:
    """
    Pick the PIL mode string that matches the channel count of an image array.

    The array must have exactly three dimensions and the size of its last axis
    is treated as the channel count: 3 maps to "RGB" and 4 maps to "RGBA".

    Raises:
        ValueError: if the array does not have three dimensions, or if its
            last axis is neither 3 nor 4 entries long.
    """
    if arr.ndim != 3:
        # include the shape so the offending array can be identified when this
        # is raised from inside a batch conversion
        raise ValueError(
            f"unknown array format: expected 3 dimensions, got shape {arr.shape}"
        )

    channels = arr.shape[-1]
    if channels == 3:
        return "RGB"
    if channels == 4:
        return "RGBA"

    raise ValueError(
        f"unknown image format: expected 3 or 4 channels, got {channels}"
    )

View File

@ -151,7 +151,9 @@ def blend_tiles(
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap "adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
) )
scaled_size = (height * scale, width * scale, 3) channels = max([4 if tile_image.mode == "RGBA" else 3 for _left, _top, tile_image in tiles])
scaled_size = (height * scale, width * scale, channels)
count = np.zeros(scaled_size) count = np.zeros(scaled_size)
value = np.zeros(scaled_size) value = np.zeros(scaled_size)
@ -221,7 +223,7 @@ def blend_tiles(
margin_left : equalized.shape[1] + margin_right, margin_left : equalized.shape[1] + margin_right,
np.newaxis, np.newaxis,
], ],
3, channels,
axis=2, axis=2,
) )

View File

@ -8,7 +8,7 @@ def mask_filter_none(
) -> Image.Image: ) -> Image.Image:
width, height = dims width, height = dims
noise = Image.new("RGB", (width, height), fill) noise = Image.new(mask.mode, (width, height), fill)
noise.paste(mask, origin) noise.paste(mask, origin)
return noise return noise

View File

@ -17,21 +17,21 @@ def noise_source_fill_edge(
""" """
width, height = dims width, height = dims
noise = Image.new("RGB", (width, height), fill) noise = Image.new(source.mode, (width, height), fill)
noise.paste(source, origin) noise.paste(source, origin)
return noise return noise
def noise_source_fill_mask( def noise_source_fill_mask(
_source: Image.Image, dims: Point, _origin: Point, fill="white", **kw source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
) -> Image.Image: ) -> Image.Image:
""" """
Fill the whole canvas, no source or noise. Fill the whole canvas, no source or noise.
""" """
width, height = dims width, height = dims
noise = Image.new("RGB", (width, height), fill) noise = Image.new(source.mode, (width, height), fill)
return noise return noise
@ -52,7 +52,7 @@ def noise_source_gaussian(
def noise_source_uniform( def noise_source_uniform(
_source: Image.Image, dims: Point, _origin: Point, **kw source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image: ) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -61,6 +61,7 @@ def noise_source_uniform(
noise_g = random.uniform(0, 256, size=size) noise_g = random.uniform(0, 256, size=size)
noise_b = random.uniform(0, 256, size=size) noise_b = random.uniform(0, 256, size=size)
# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height)) noise = Image.new("RGB", (width, height))
for x in range(width): for x in range(width):
@ -68,11 +69,11 @@ def noise_source_uniform(
i = get_pixel_index(x, y, width) i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
return noise return noise.convert(source.mode)
def noise_source_normal( def noise_source_normal(
_source: Image.Image, dims: Point, _origin: Point, **kw source: Image.Image, dims: Point, _origin: Point, **kw
) -> Image.Image: ) -> Image.Image:
width, height = dims width, height = dims
size = width * height size = width * height
@ -81,6 +82,7 @@ def noise_source_normal(
noise_g = random.normal(128, 32, size=size) noise_g = random.normal(128, 32, size=size)
noise_b = random.normal(128, 32, size=size) noise_b = random.normal(128, 32, size=size)
# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height)) noise = Image.new("RGB", (width, height))
for x in range(width): for x in range(width):
@ -88,7 +90,7 @@ def noise_source_normal(
i = get_pixel_index(x, y, width) i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i]))) noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
return noise return noise.convert(source.mode)
def noise_source_histogram( def noise_source_histogram(
@ -112,6 +114,7 @@ def noise_source_histogram(
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size 256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
) )
# needs to be RGB for pixel manipulation
noise = Image.new("RGB", (width, height)) noise = Image.new("RGB", (width, height))
for x in range(width): for x in range(width):
@ -119,4 +122,4 @@ def noise_source_histogram(
i = get_pixel_index(x, y, width) i = get_pixel_index(x, y, width)
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i])) noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
return noise return noise.convert(source.mode)

View File

@ -20,7 +20,7 @@ def expand_image(
size = tuple(size) size = tuple(size)
origin = (expand.left, expand.top) origin = (expand.left, expand.top)
full_source = Image.new("RGB", size, fill) full_source = Image.new(source.mode, size, fill)
full_source.paste(source, origin) full_source.paste(source, origin)
# new mask pixels need to be filled with white so they will be replaced # new mask pixels need to be filled with white so they will be replaced