1
0
Fork 0
onnx-web/api/onnx_web/chain/blend_grid.py

63 lines
1.6 KiB
Python
Raw Normal View History

from logging import getLogger
2023-11-19 00:13:13 +00:00
from typing import Optional
from PIL import Image
from ..params import ImageParams, Size, SizeChart, StageParams
from ..server import ServerContext
from ..worker import ProgressCallback, WorkerContext
from .base import BaseStage
from .result import StageResult
# Module-level logger, named after this module.
logger = getLogger(__name__)
class BlendGridStage(BaseStage):
    """Chain stage that tiles the source images into a single grid image.

    The grid image is appended after the source images in the result, so the
    stage emits one more image than it receives.
    """

    max_tile = SizeChart.max

    def run(
        self,
        _worker: WorkerContext,
        _server: ServerContext,
        _stage: StageParams,
        _params: ImageParams,
        sources: StageResult,
        *,
        height: int,
        width: int,
        # rows: Optional[List[str]] = None,
        # columns: Optional[List[str]] = None,
        # title: Optional[str] = None,
        order: Optional[int] = None,
        stage_source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs,
    ) -> StageResult:
        """Paste the source images onto one canvas laid out as a grid.

        Every cell has the size of the first source image. Cells are filled
        left to right, then top to bottom.

        Args:
            _worker: Unused by this stage.
            _server: Unused by this stage.
            _stage: Unused by this stage.
            _params: Unused by this stage.
            sources: Result of the previous stage; its images are the tiles.
            height: Number of grid rows, in tiles.
            width: Number of grid columns, in tiles.
            order: Indices into the source images, one per grid cell, in the
                order the cells are filled. Defaults to the source order.
                NOTE(review): annotated as ``Optional[int]`` but consumed as
                an iterable of indices; the annotation is kept as-is so the
                signature does not change.
            stage_source: Unused by this stage.
            callback: Unused by this stage.
            **kwargs: Ignored.

        Returns:
            A ``StageResult`` holding all source images followed by the grid.
        """
        logger.info("combining source images using grid layout")

        images = sources.as_image()
        # The first image fixes the cell size and the colour mode of the
        # canvas; this assumes at least one source image is present.
        ref_image = images[0]
        size = Size(*ref_image.size)

        output = Image.new(ref_image.mode, (size.width * width, size.height * height))

        # TODO: labels
        layout = range(len(images)) if order is None else order

        for slot, source_index in enumerate(layout):
            # Convert the linear slot number into grid coordinates.
            column = slot % width
            row = slot // width
            output.paste(
                images[source_index],
                (column * size.width, row * size.height),
            )

        return StageResult(images=[*images, output])

    def outputs(
        self,
        _params: ImageParams,
        sources: int,
    ) -> int:
        """Return the number of images this stage emits for ``sources`` inputs.

        The stage passes every input through and adds the grid image, hence
        one more than the input count.
        """
        return sources + 1