From 1fbee0ae52b4749fc2f55af315a961818aa6e059 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 20 Dec 2023 23:33:13 -0600 Subject: [PATCH] fix(api): tile stage masks --- api/onnx_web/chain/blend_mask.py | 8 +++++--- api/onnx_web/chain/tile.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/chain/blend_mask.py b/api/onnx_web/chain/blend_mask.py index 52f038e7..f85fad1d 100644 --- a/api/onnx_web/chain/blend_mask.py +++ b/api/onnx_web/chain/blend_mask.py @@ -26,20 +26,22 @@ class BlendMaskStage(BaseStage): dims: Tuple[int, int, int], stage_source: Optional[Image.Image] = None, stage_mask: Optional[Image.Image] = None, + tile_mask: Optional[Image.Image] = None, _callback: Optional[ProgressCallback] = None, **kwargs, ) -> StageResult: logger.info("blending image using mask") - mult_mask = Image.new(stage_mask.mode, stage_mask.size, color="black") - mult_mask = Image.alpha_composite(mult_mask, stage_mask) + mask_source = tile_mask or stage_mask + mult_mask = Image.new(mask_source.mode, mask_source.size, color="black") + mult_mask = Image.alpha_composite(mult_mask, mask_source) mult_mask = mult_mask.convert("L") top, left, tile = dims stage_source_tile = stage_source.crop((left, top, left + tile, top + tile)) if is_debug(): - save_image(server, "last-mask.png", stage_mask) + save_image(server, "last-mask.png", mask_source) save_image(server, "last-mult-mask.png", mult_mask) save_image(server, "last-stage-source.png", stage_source_tile) diff --git a/api/onnx_web/chain/tile.py b/api/onnx_web/chain/tile.py index 01ea9f2b..5716fd22 100644 --- a/api/onnx_web/chain/tile.py +++ b/api/onnx_web/chain/tile.py @@ -258,7 +258,7 @@ def process_tile_stack( sources = stack.as_image() width, height = kwargs.get("size", sources[0].size if len(sources) > 0 else None) - mask = kwargs.get("mask", None) + mask = kwargs.get("mask", kwargs.get("stage_mask", None)) noise_source = kwargs.get("noise_source", noise_source_histogram) fill_color = kwargs.get("fill_color", None) if not mask: