diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 4486bbf6..326ba162 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -30,17 +30,16 @@ class BlendMaskStage(BaseStage): ) -> StageResult: logger.info("blending image using mask") - # TODO: does this need an alpha channel? mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black") - mult_mask.alpha_composite(stage_mask) + mult_mask = Image.alpha_composite(mult_mask, stage_mask) mult_mask = mult_mask.convert("L") if is_debug(): save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) - return StageResult( - images=[ + return StageResult.from_images( + [ Image.composite(stage_source, source, mult_mask) for source in sources.as_image() ]