1
0
Fork 0
onnx-web/api/onnx_web/chain/base.py

211 lines
5.9 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from datetime import timedelta
from logging import getLogger
2023-01-28 20:56:06 +00:00
from time import monotonic
from typing import Any, List, Optional, Protocol, Tuple
2023-02-05 13:53:26 +00:00
from PIL import Image
from ..output import save_image
from ..params import ImageParams, StageParams
2023-02-26 05:49:39 +00:00
from ..server import ServerContext
2023-02-19 02:28:21 +00:00
from ..utils import is_debug
2023-02-26 20:15:30 +00:00
from ..worker import ProgressCallback, WorkerContext
from .utils import process_tile_order
2023-01-28 23:09:19 +00:00
# module-scoped logger, named after this module
logger = getLogger(name=__name__)
class StageCallback(Protocol):
    """
    Structural type for a pipeline stage: any callable that accepts the arguments
    below and returns an image satisfies this protocol.
    """

    def __call__(
        self,
        job: WorkerContext,
        server: ServerContext,
        stage: StageParams,
        params: ImageParams,
        source: Image.Image,
        **kwargs: Any
    ) -> Image.Image:
        """
        Apply the stage to ``source`` and return the resulting image.

        Any extra keyword arguments configured for the stage (for example a
        progress ``callback``) arrive through ``kwargs``.
        """
        ...
# One pipeline entry: (stage callable, its StageParams, extra keyword arguments
# for that stage or None). ChainPipeline unpacks entries in exactly this order.
PipelineStage = Tuple[StageCallback, StageParams, Optional[dict]]
class ChainProgress:
    """
    Progress adapter that turns per-stage step counts into one running total.

    Each stage reports its own step numbers, which start over when the next
    stage (or tile) begins. Whenever a reported step goes backwards, the steps
    counted so far are folded into ``total``, so the value forwarded to the
    parent callback keeps increasing across the whole pipeline.
    """

    def __init__(self, parent: ProgressCallback, start: int = 0) -> None:
        """
        Wrap ``parent``, with the step counter initialised to ``start``.
        """
        self.parent = parent
        self.step = start  # last step reported by the current stage
        self.total = 0  # steps accumulated from earlier resets

    def __call__(self, step: int, timestep: int, latents: Any) -> None:
        """
        Record ``step`` and forward the accumulated total to the parent callback.

        ``timestep`` and ``latents`` are passed through unchanged.
        """
        if step < self.step:
            # the counter went backwards: a new stage/tile started its own
            # count over, so keep the steps completed so far
            self.total += self.step

        self.step = step
        self.parent(self.get_total(), timestep, latents)

    def get_total(self) -> int:
        """
        Return the number of steps reported so far, including earlier resets.
        """
        return self.step + self.total

    @classmethod
    def from_progress(cls, parent: ProgressCallback) -> "ChainProgress":
        """
        Wrap an existing progress callback.

        Counting continues from ``parent.step`` when the parent has that
        attribute (e.g. another ``ChainProgress``), otherwise from zero. Uses
        ``cls`` so subclasses get an instance of their own type.
        """
        start = getattr(parent, "step", 0)
        return cls(parent, start=start)
class ChainPipeline:
    """
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    """

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.

        The list is copied, so later changes to the caller's list do not affect
        this pipeline.
        """
        self.stages = list(stages or [])

    def append(self, stage: PipelineStage):
        """
        DEPRECATED: use `stage` instead

        Append an additional stage to this pipeline. A ``None`` stage is ignored.
        """
        if stage is not None:
            self.stages.append(stage)

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **kwargs
    ) -> Image.Image:
        """
        Run every stage in order and return the final image.

        ``source`` and ``callback`` default to None, matching ``__call__``.

        TODO: handle List[Image] inputs and outputs
        """
        return self(job, server, params, source=source, callback=callback, **kwargs)

    def stage(self, callback: StageCallback, params: StageParams, **kwargs):
        """
        Append a stage with its own extra keyword arguments.

        Returns this pipeline so calls can be chained.
        """
        self.stages.append((callback, params, kwargs))
        return self

    def __call__(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        source: Optional[Image.Image] = None,
        callback: Optional[ProgressCallback] = None,
        **pipeline_kwargs
    ) -> Image.Image:
        """
        DEPRECATED: use `run` instead

        Run each stage in order, feeding the output of one stage into the next.

        Args:
            job, server, params: passed through to every stage.
            source: the initial image, or None to start without one.
            callback: optional progress callback. It is wrapped in ChainProgress
                so step counts keep increasing across stages and tiles.
            pipeline_kwargs: extra keyword arguments for every stage. A stage's
                own kwargs take precedence on conflicting keys.

        Returns:
            The image returned by the last stage, or ``source`` when there are
            no stages.
        """
        if callback is not None:
            callback = ChainProgress.from_progress(callback)

        start = monotonic()

        if source is not None:
            logger.info(
                "running pipeline on source image with dimensions %sx%s",
                source.width,
                source.height,
            )
        else:
            logger.info("running pipeline without source image")

        image = source
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            # stage-specific kwargs override the pipeline-wide ones
            kwargs = {**pipeline_kwargs, **(stage_kwargs or {})}
            image = self._run_stage(
                job,
                server,
                params,
                image,
                callback,
                stage_pipe,
                stage_params,
                kwargs,
            )

        end = monotonic()
        duration = timedelta(seconds=(end - start))

        # an empty pipeline without a source image ends with image=None
        if image is None:
            logger.info("finished pipeline in %s without result image", duration)
        else:
            logger.info(
                "finished pipeline in %s, result size: %sx%s",
                duration,
                image.width,
                image.height,
            )

        return image

    def _run_stage(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        image: Optional[Image.Image],
        callback: Optional[ProgressCallback],
        stage_pipe: StageCallback,
        stage_params: StageParams,
        kwargs: dict,
    ) -> Image.Image:
        """
        Run one stage on ``image``.

        When the image is larger than the stage's tile size, the stage is
        applied tile by tile via ``process_tile_order``. Returns the stage
        result.
        """
        # callable objects implementing StageCallback need not have __name__
        name = stage_params.name or getattr(
            stage_pipe, "__name__", type(stage_pipe).__name__
        )

        if image is not None:
            logger.debug(
                "running stage %s on source image with dimensions %sx%s, %s",
                name,
                image.width,
                image.height,
                kwargs.keys(),
            )
        else:
            logger.debug(
                "running stage %s without source image, %s", name, kwargs.keys()
            )

        if image is not None and (
            image.width > stage_params.tile_size
            or image.height > stage_params.tile_size
        ):
            logger.info(
                "image larger than tile size of %s, tiling stage",
                stage_params.tile_size,
            )

            def stage_tile(tile: Image.Image, _dims) -> Image.Image:
                tile = stage_pipe(
                    job,
                    server,
                    stage_params,
                    params,
                    tile,
                    callback=callback,
                    **kwargs
                )

                if is_debug():
                    save_image(server, "last-tile.png", tile)

                return tile

            image = process_tile_order(
                stage_params.tile_order,
                image,
                stage_params.tile_size,
                stage_params.outscale,
                [stage_tile],
            )
        else:
            logger.debug("image within tile size, running stage")
            image = stage_pipe(
                job,
                server,
                stage_params,
                params,
                image,
                callback=callback,
                **kwargs
            )

        if image is None:
            logger.debug("finished stage %s without result image", name)
        else:
            logger.debug(
                "finished stage %s, result size: %sx%s", name, image.width, image.height
            )

            if is_debug():
                save_image(server, "last-stage.png", image)

        return image