1
0
Fork 0
onnx-web/api/onnx_web/chain/base.py

112 lines
3.3 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from datetime import timedelta
from logging import getLogger
from PIL import Image
2023-01-28 20:56:06 +00:00
from time import monotonic
from typing import Any, List, Optional, Protocol, Tuple
2023-02-04 20:27:27 +00:00
from ..device_pool import (
JobContext,
)
from ..params import (
ImageParams,
StageParams,
)
from ..output import (
save_image,
)
from ..utils import (
is_debug,
ServerContext,
)
from .utils import (
process_tile_grid,
)
2023-01-28 23:09:19 +00:00
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
class StageCallback(Protocol):
    '''
    Structural type for a pipeline stage: any callable taking the server context, the stage
    parameters, the image parameters, and a source image, plus arbitrary keyword arguments,
    and returning an image.
    '''

    def __call__(
        self,
        ctx: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        **kwargs: Any
    ) -> Image.Image:
        '''
        Run the stage on the source image and return the resulting image.
        '''
        ...
# One pipeline entry: the stage callable, the StageParams for that stage, and an optional dict
# of extra keyword arguments for the stage (None is treated as no extra arguments).
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainPipeline:
    '''
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    '''

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        '''
        Create a new pipeline that will run the given stages.

        The stages are copied into a new list, so appending to the pipeline does not modify the
        list passed in by the caller. Omitting the argument, or passing None, creates an empty
        pipeline.
        '''
        # None is used instead of a mutable [] default, which would be shared between calls
        self.stages = [] if stages is None else list(stages)

    def append(self, stage: PipelineStage):
        '''
        Append an additional stage to this pipeline.
        '''
        self.stages.append(stage)

    def __call__(
        self,
        job: JobContext,
        server: ServerContext,
        params: ImageParams,
        source: Image.Image,
        **pipeline_kwargs
    ) -> Image.Image:
        '''
        Run each stage in order, feeding the output image of one stage into the next.

        When the current image is wider or taller than a stage's tile size, that stage is run
        through process_tile_grid using the stage's tile size and outscale; otherwise the stage
        is called once on the whole image. In debug mode, the most recent tile and stage results
        are saved as 'last-tile.png' and 'last-stage.png'.

        job: the job context; not used by this method.
        server: the server context, passed to every stage and used when saving debug images.
        params: the image parameters, passed to every stage.
        source: the image given to the first stage.
        pipeline_kwargs: keyword arguments passed to every stage; a stage's own kwargs take
            precedence when both define the same key.

        Returns the image produced by the last stage, or the source image if there are no stages.

        TODO: handle List[Image] outputs
        '''
        start = monotonic()
        logger.info('running pipeline on source image with dimensions %sx%s',
                    source.width, source.height)
        image = source

        for stage_pipe, stage_params, stage_kwargs in self.stages:
            # not every callable has __name__ (partials, callable instances), so fall back to
            # the type name rather than failing while building a log label
            name = stage_params.name or getattr(
                stage_pipe, '__name__', type(stage_pipe).__name__)

            # stage-specific kwargs override the pipeline-wide ones
            kwargs = {**pipeline_kwargs, **(stage_kwargs or {})}

            logger.info('running stage %s on image with dimensions %sx%s, %s',
                        name, image.width, image.height, kwargs.keys())

            if image.width > stage_params.tile_size or image.height > stage_params.tile_size:
                logger.info('image larger than tile size of %s, tiling stage',
                            stage_params.tile_size)

                def stage_tile(tile: Image.Image, _dims) -> Image.Image:
                    tile = stage_pipe(server, stage_params, params, tile,
                                      **kwargs)

                    if is_debug():
                        save_image(server, 'last-tile.png', tile)

                    return tile

                image = process_tile_grid(
                    image, stage_params.tile_size, stage_params.outscale, [stage_tile])
            else:
                logger.info('image within tile size, running stage')
                image = stage_pipe(server, stage_params, params, image,
                                   **kwargs)

            logger.info('finished stage %s, result size: %sx%s',
                        name, image.width, image.height)

            if is_debug():
                save_image(server, 'last-stage.png', image)

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info('finished pipeline in %s, result size: %sx%s',
                    duration, image.width, image.height)

        return image