1
0
Fork 0
onnx-web/api/onnx_web/chain/base.py

104 lines
3.1 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from datetime import timedelta
from logging import getLogger
from PIL import Image
from os import path
2023-01-28 20:56:06 +00:00
from time import monotonic
from typing import Any, List, Optional, Protocol, Tuple
from ..params import (
ImageParams,
StageParams,
)
from ..utils import (
is_debug,
ServerContext,
)
from .utils import (
process_tiles,
)
2023-01-28 23:09:19 +00:00
logger = getLogger(__name__)


class StageCallback(Protocol):
    '''
    Structural (duck-typed) interface for a single pipeline stage.

    Any callable matching this signature can be used as a stage. It receives:
    - the server context,
    - the stage's own parameters,
    - the image parameters,
    - the image to process,
    - arbitrary keyword arguments.

    It returns the processed image.
    '''

    def __call__(
        self,
        ctx: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        **kwargs: Any
    ) -> Image.Image:
        ...


# One pipeline entry: (stage callback, that stage's StageParams, optional extra kwargs for the callback).
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainPipeline:
    '''
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    '''

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        '''
        Create a new pipeline that will run the given stages.

        When `stages` is omitted, the pipeline gets its own fresh, empty list. The previous
        default of `[]` was evaluated once and shared by every default-constructed pipeline,
        so stages appended to one pipeline leaked into all the others.

        A list passed in by the caller is stored as-is (not copied), as before.
        '''
        self.stages = [] if stages is None else stages

    def append(self, stage: PipelineStage):
        '''
        Append an additional stage to this pipeline.

        This mutates `self.stages`, which is the same list object that was given to the
        constructor, if one was given.
        '''
        self.stages.append(stage)

    def __call__(self, ctx: ServerContext, params: ImageParams, source: Image.Image, **pipeline_kwargs) -> Image.Image:
        '''
        Run every stage in order, feeding each stage's output image into the next stage.

        Keyword arguments given here are forwarded to every stage. A stage's own kwargs (the
        third element of its `PipelineStage` tuple) win on key conflicts.

        If the current image is wider or taller than the stage's `tile_size`, the stage is
        applied through `process_tiles`, which receives the stage's `tile_size` and `outscale`.
        Otherwise the stage is called once on the whole image.

        Returns the image produced by the last stage, or `source` itself when there are no
        stages.

        TODO: handle List[Image] outputs
        '''
        start = monotonic()
        logger.info('running pipeline on source image with dimensions %sx%s',
                    source.width, source.height)

        image = source
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            # Callable objects satisfy StageCallback but may have no __name__; fall back to the
            # class name so an unnamed stage does not raise AttributeError here.
            name = stage_params.name or getattr(stage_pipe, '__name__', type(stage_pipe).__name__)

            # Per-stage kwargs override pipeline-wide kwargs.
            kwargs = {**pipeline_kwargs, **(stage_kwargs or {})}
            logger.info('running stage %s on result image with dimensions %sx%s, %s',
                        name, image.width, image.height, kwargs)

            if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
                logger.info('source image larger than tile size of %s, tiling stage',
                            stage_params.tile_size)

                # This closure reads the loop variables (stage_pipe, stage_params, kwargs) late.
                # That is safe because process_tiles is called within this same iteration.
                def stage_tile(tile: Image.Image, _dims) -> Image.Image:
                    tile = stage_pipe(ctx, stage_params, params, tile,
                                      **kwargs)
                    if is_debug():
                        # The same file name is used every time, so only the most recent tile is kept.
                        tile.save(path.join(ctx.output_path, 'last-tile.png'))
                    return tile

                image = process_tiles(
                    image, stage_params.tile_size, stage_params.outscale, [stage_tile])
            else:
                logger.info('source image within tile size, running stage')
                image = stage_pipe(ctx, stage_params, params, image,
                                   **kwargs)

            logger.info('finished stage %s, result size: %sx%s',
                        name, image.width, image.height)

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info('finished pipeline in %s, result size: %sx%s',
                    duration, image.width, image.height)
        return image