diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 326ba162..bda9b1ab 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Optional +from typing import Optional, Tuple from PIL import Image @@ -23,6 +23,7 @@ class BlendMaskStage(BaseStage): _params: ImageParams, sources: StageResult, *, + dims: Tuple[int, int, int], stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, _callback: Optional[ProgressCallback] = None, @@ -34,13 +35,17 @@ class BlendMaskStage(BaseStage): mult_mask = Image.alpha_composite(mult_mask, stage_mask) mult_mask = mult_mask.convert("L") + top, left, tile = dims + stage_source_tile = stage_source.crop((top, left, tile, tile)) + if is_debug(): save_image(server, "last-mask.png", stage_mask) save_image(server, "last-mult-mask.png", mult_mask) + save_image(server, "last-stage-source.png", stage_source_tile) return StageResult.from_images( [ - Image.composite(stage_source, source, mult_mask) + Image.composite(stage_source_tile, source, mult_mask) for source in sources.as_image() ] )