1
0
Fork 0
onnx-web/api/onnx_web/chain/blend_mask.py

55 lines
1.7 KiB
Python
Raw Permalink Normal View History

from logging import getLogger
from typing import Optional, Tuple
from PIL import Image
2023-02-27 02:09:42 +00:00
from ..output import save_image
from ..params import ImageParams, StageParams
2023-02-26 05:49:39 +00:00
from ..server import ServerContext
2023-02-19 02:28:21 +00:00
from ..utils import is_debug
2023-02-26 20:15:30 +00:00
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
# Module-level logger named after this module.
logger = getLogger(__name__)
class BlendMaskStage(BaseStage):
    """Chain stage that blends a tile of the stage source into each input image through a mask."""

    def run(
        self,
        _worker: WorkerContext,
        server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        dims: Tuple[int, int, int],
        stage_source: Optional[Image.Image] = None,
        stage_mask: Optional[Image.Image] = None,
        tile_mask: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        """Composite a crop of ``stage_source`` over every image in ``sources``.

        The mask (``tile_mask`` when given, otherwise ``stage_mask``) is
        flattened onto a black background and reduced to a single-channel
        ``L`` image, which is then used as the blend mask for
        ``Image.composite``.

        Args:
            _worker: Unused by this stage.
            server: Passed to ``save_image`` when debug output is enabled.
            _stage: Unused by this stage.
            _params: Unused by this stage.
            sources: Images to blend into; its metadata is passed through
                to the result unchanged.
            dims: ``(left, top, tile)``; the crop box taken from
                ``stage_source`` is the square
                ``(left, top, left + tile, top + tile)``.
            stage_source: Image the tile is cropped from. Required.
            stage_mask: Fallback mask, used when ``tile_mask`` is not given.
            tile_mask: Preferred mask.
            callback: Unused by this stage.
            **kwargs: Ignored.

        Returns:
            A ``StageResult`` holding one blended image per input image.

        Raises:
            ValueError: If neither mask was provided, or ``stage_source``
                is missing.
        """
        logger.info("blending image using mask")

        mask_source = tile_mask or stage_mask
        if mask_source is None:
            raise ValueError(
                "blend mask stage requires either a tile mask or a stage mask"
            )

        if stage_source is None:
            raise ValueError("blend mask stage requires a stage source image")

        # Image.alpha_composite only accepts RGBA images, so normalise the
        # mask first; an RGBA mask is used as-is and gives the same result
        # as before.
        if mask_source.mode == "RGBA":
            rgba_mask = mask_source
        else:
            rgba_mask = mask_source.convert("RGBA")

        # Flatten the mask's alpha channel onto black, then reduce it to a
        # single-channel mask usable by Image.composite.
        mult_mask = Image.new("RGBA", rgba_mask.size, color="black")
        mult_mask = Image.alpha_composite(mult_mask, rgba_mask)
        mult_mask = mult_mask.convert("L")

        left, top, tile = dims
        stage_source_tile = stage_source.crop((left, top, left + tile, top + tile))

        if is_debug():
            save_image(server, "last-mask.png", mask_source)
            save_image(server, "last-mult-mask.png", mult_mask)
            save_image(server, "last-stage-source.png", stage_source_tile)

        return StageResult.from_images(
            [
                Image.composite(stage_source_tile, source, mult_mask)
                for source in sources.as_images()
            ],
            metadata=sources.metadata,
        )