fix(api): better handling of alpha channels
This commit is contained in:
parent
c134edf4b3
commit
1c3b2f8dfc
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
|
||||
from PIL import Image
|
||||
|
||||
from ..params import ImageParams, SizeChart, StageParams
|
||||
from ..params import ImageParams, Size, SizeChart, StageParams
|
||||
from ..server import ServerContext
|
||||
from ..worker import ProgressCallback, WorkerContext
|
||||
from .base import BaseStage
|
||||
|
@ -36,9 +36,10 @@ class BlendGridStage(BaseStage):
|
|||
logger.info("combining source images using grid layout")
|
||||
|
||||
images = sources.as_image()
|
||||
size = images[0].size
|
||||
ref_image = images[0]
|
||||
size = Size(*ref_image.size)
|
||||
|
||||
output = Image.new("RGB", (size[0] * width, size[1] * height))
|
||||
output = Image.new(ref_image.mode, (size.width * width, size.height * height))
|
||||
|
||||
# TODO: labels
|
||||
if order is None:
|
||||
|
@ -49,7 +50,7 @@ class BlendGridStage(BaseStage):
|
|||
y = i // width
|
||||
|
||||
n = order[i]
|
||||
output.paste(images[n], (x * size[0], y * size[1]))
|
||||
output.paste(images[n], (x * size.width, y * size.height))
|
||||
|
||||
return StageResult(images=[*images, output])
|
||||
|
||||
|
|
|
@ -30,7 +30,8 @@ class BlendMaskStage(BaseStage):
|
|||
) -> StageResult:
|
||||
logger.info("blending image using mask")
|
||||
|
||||
mult_mask = Image.new("RGBA", stage_mask.size, color="black")
|
||||
# TODO: does this need an alpha channel?
|
||||
mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black")
|
||||
mult_mask.alpha_composite(stage_mask)
|
||||
mult_mask = mult_mask.convert("L")
|
||||
|
||||
|
|
|
@ -52,4 +52,16 @@ class StageResult:
|
|||
if self.images is not None:
|
||||
return self.images
|
||||
|
||||
return [Image.fromarray(np.uint8(i), "RGB") for i in self.arrays]
|
||||
return [Image.fromarray(np.uint8(i), shape_mode(i)) for i in self.arrays]
|
||||
|
||||
|
||||
def shape_mode(arr: np.ndarray) -> str:
|
||||
if len(arr.shape) != 3:
|
||||
raise ValueError("unknown array format")
|
||||
|
||||
if arr.shape[-1] == 3:
|
||||
return "RGB"
|
||||
elif arr.shape[-1] == 4:
|
||||
return "RGBA"
|
||||
|
||||
raise ValueError("unknown image format")
|
|
@ -151,7 +151,9 @@ def blend_tiles(
|
|||
"adjusting tile size from %s to %s based on %s overlap", tile, adj_tile, overlap
|
||||
)
|
||||
|
||||
scaled_size = (height * scale, width * scale, 3)
|
||||
channels = max([4 if tile_image.mode == "RGBA" else 3 for _left, _top, tile_image in tiles])
|
||||
scaled_size = (height * scale, width * scale, channels)
|
||||
|
||||
count = np.zeros(scaled_size)
|
||||
value = np.zeros(scaled_size)
|
||||
|
||||
|
@ -221,7 +223,7 @@ def blend_tiles(
|
|||
margin_left : equalized.shape[1] + margin_right,
|
||||
np.newaxis,
|
||||
],
|
||||
3,
|
||||
channels,
|
||||
axis=2,
|
||||
)
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ def mask_filter_none(
|
|||
) -> Image.Image:
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
noise = Image.new(mask.mode, (width, height), fill)
|
||||
noise.paste(mask, origin)
|
||||
|
||||
return noise
|
||||
|
|
|
@ -17,21 +17,21 @@ def noise_source_fill_edge(
|
|||
"""
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
noise = Image.new(source.mode, (width, height), fill)
|
||||
noise.paste(source, origin)
|
||||
|
||||
return noise
|
||||
|
||||
|
||||
def noise_source_fill_mask(
|
||||
_source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
|
||||
source: Image.Image, dims: Point, _origin: Point, fill="white", **kw
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Fill the whole canvas, no source or noise.
|
||||
"""
|
||||
width, height = dims
|
||||
|
||||
noise = Image.new("RGB", (width, height), fill)
|
||||
noise = Image.new(source.mode, (width, height), fill)
|
||||
|
||||
return noise
|
||||
|
||||
|
@ -52,7 +52,7 @@ def noise_source_gaussian(
|
|||
|
||||
|
||||
def noise_source_uniform(
|
||||
_source: Image.Image, dims: Point, _origin: Point, **kw
|
||||
source: Image.Image, dims: Point, _origin: Point, **kw
|
||||
) -> Image.Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
@ -61,6 +61,7 @@ def noise_source_uniform(
|
|||
noise_g = random.uniform(0, 256, size=size)
|
||||
noise_b = random.uniform(0, 256, size=size)
|
||||
|
||||
# needs to be RGB for pixel manipulation
|
||||
noise = Image.new("RGB", (width, height))
|
||||
|
||||
for x in range(width):
|
||||
|
@ -68,11 +69,11 @@ def noise_source_uniform(
|
|||
i = get_pixel_index(x, y, width)
|
||||
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||
|
||||
return noise
|
||||
return noise.convert(source.mode)
|
||||
|
||||
|
||||
def noise_source_normal(
|
||||
_source: Image.Image, dims: Point, _origin: Point, **kw
|
||||
source: Image.Image, dims: Point, _origin: Point, **kw
|
||||
) -> Image.Image:
|
||||
width, height = dims
|
||||
size = width * height
|
||||
|
@ -81,6 +82,7 @@ def noise_source_normal(
|
|||
noise_g = random.normal(128, 32, size=size)
|
||||
noise_b = random.normal(128, 32, size=size)
|
||||
|
||||
# needs to be RGB for pixel manipulation
|
||||
noise = Image.new("RGB", (width, height))
|
||||
|
||||
for x in range(width):
|
||||
|
@ -88,7 +90,7 @@ def noise_source_normal(
|
|||
i = get_pixel_index(x, y, width)
|
||||
noise.putpixel((x, y), (int(noise_r[i]), int(noise_g[i]), int(noise_b[i])))
|
||||
|
||||
return noise
|
||||
return noise.convert(source.mode)
|
||||
|
||||
|
||||
def noise_source_histogram(
|
||||
|
@ -112,6 +114,7 @@ def noise_source_histogram(
|
|||
256, p=np.divide(np.copy(hist_b), np.sum(hist_b)), size=size
|
||||
)
|
||||
|
||||
# needs to be RGB for pixel manipulation
|
||||
noise = Image.new("RGB", (width, height))
|
||||
|
||||
for x in range(width):
|
||||
|
@ -119,4 +122,4 @@ def noise_source_histogram(
|
|||
i = get_pixel_index(x, y, width)
|
||||
noise.putpixel((x, y), (noise_r[i], noise_g[i], noise_b[i]))
|
||||
|
||||
return noise
|
||||
return noise.convert(source.mode)
|
||||
|
|
|
@ -20,7 +20,7 @@ def expand_image(
|
|||
size = tuple(size)
|
||||
origin = (expand.left, expand.top)
|
||||
|
||||
full_source = Image.new("RGB", size, fill)
|
||||
full_source = Image.new(source.mode, size, fill)
|
||||
full_source.paste(source, origin)
|
||||
|
||||
# new mask pixels need to be filled with white so they will be replaced
|
||||
|
|
Loading…
Reference in New Issue