1
0
Fork 0
onnx-web/api/onnx_web/chain/base.py

126 lines
3.5 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from datetime import timedelta
from logging import getLogger
2023-01-28 20:56:06 +00:00
from time import monotonic
from typing import Any, List, Optional, Protocol, Tuple
2023-02-05 13:53:26 +00:00
from PIL import Image
from ..device_pool import JobContext
from ..output import save_image
from ..params import ImageParams, StageParams
from ..utils import ServerContext, is_debug
from .utils import process_tile_grid
2023-01-28 23:09:19 +00:00
# Module-level logger, named after this module so records can be filtered per module.
logger = getLogger(name=__name__)
class StageCallback(Protocol):
    """
    Structural type for one pipeline stage.

    Any callable with this signature can be used as a stage: it receives the job and server
    contexts, the stage's own parameters, the image parameters, a source image, and any extra
    keyword arguments, and returns the resulting image.
    """

    def __call__(
        self,
        job: JobContext,
        ctx: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        **kwargs: Any
    ) -> Image.Image:
        ...
# One entry in a ChainPipeline: the stage callback to run, its stage parameters, and an
# optional dict of extra keyword arguments for that stage (None means "no extra kwargs").
PipelineStage = Tuple[
    StageCallback,
    StageParams,
    Optional[dict],
]
class ChainPipeline:
    """
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    """

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.

        The stages are copied into a new list, so appending to this pipeline never mutates the
        caller's list. Passing nothing (or None) creates an empty pipeline.
        """
        # Avoid a shared mutable default argument; always keep a private copy of the stages.
        self.stages = list(stages) if stages is not None else []

    def append(self, stage: PipelineStage):
        """
        Append an additional stage to this pipeline.
        """
        self.stages.append(stage)

    def __call__(
        self,
        job: JobContext,
        server: ServerContext,
        params: ImageParams,
        source: Image.Image,
        **pipeline_kwargs
    ) -> Image.Image:
        """
        Run every stage in order, feeding each stage's output image into the next stage.

        Each stage receives `pipeline_kwargs` merged with its own optional kwargs; the stage's
        own kwargs win on key conflicts. When the current image is wider or taller than the
        stage's `tile_size`, the stage is run through `process_tile_grid` with a per-tile
        callback; otherwise it is run once on the whole image. When `is_debug()` is true, each
        tile result is saved as "last-tile.png" and each stage result as "last-stage.png" via
        `save_image`.

        Returns the image produced by the last stage, or `source` itself when the pipeline has
        no stages.

        TODO: handle List[Image] outputs
        """
        start = monotonic()

        logger.info(
            "running pipeline on source image with dimensions %sx%s",
            source.width,
            source.height,
        )

        image = source
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            # Callable objects (class instances, functools.partial, ...) may not define
            # __name__, so fall back to their type name instead of raising AttributeError.
            name = stage_params.name or getattr(
                stage_pipe, "__name__", type(stage_pipe).__name__
            )
            # Stage-specific kwargs take precedence over the pipeline-wide kwargs.
            kwargs = {**pipeline_kwargs, **(stage_kwargs or {})}

            logger.info(
                "running stage %s on image with dimensions %sx%s, %s",
                name,
                image.width,
                image.height,
                kwargs.keys(),
            )

            if (
                image.width > stage_params.tile_size
                or image.height > stage_params.tile_size
            ):
                logger.info(
                    "image larger than tile size of %s, tiling stage",
                    stage_params.tile_size,
                )

                # NOTE(review): this closure reads the current loop variables; that is only
                # correct if process_tile_grid invokes it before returning — confirm.
                def stage_tile(tile: Image.Image, _dims) -> Image.Image:
                    tile = stage_pipe(job, server, stage_params, params, tile, **kwargs)
                    if is_debug():
                        save_image(server, "last-tile.png", tile)
                    return tile

                image = process_tile_grid(
                    image, stage_params.tile_size, stage_params.outscale, [stage_tile]
                )
            else:
                logger.info("image within tile size, running stage")
                image = stage_pipe(job, server, stage_params, params, image, **kwargs)

            logger.info(
                "finished stage %s, result size: %sx%s", name, image.width, image.height
            )

            if is_debug():
                save_image(server, "last-stage.png", image)

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info(
            "finished pipeline in %s, result size: %sx%s",
            duration,
            image.width,
            image.height,
        )

        return image