1
0
Fork 0
onnx-web/api/onnx_web/chain/base.py

214 lines
6.6 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from datetime import timedelta
from logging import getLogger
2023-01-28 20:56:06 +00:00
from time import monotonic
2023-07-04 17:09:46 +00:00
from typing import Any, List, Optional, Tuple
2023-02-05 13:53:26 +00:00
from PIL import Image
from ..output import save_image
from ..params import ImageParams, StageParams
2023-02-26 05:49:39 +00:00
from ..server import ServerContext
2023-02-19 02:28:21 +00:00
from ..utils import is_debug
2023-02-26 20:15:30 +00:00
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .tile import needs_tile, process_tile_order
2023-01-28 23:09:19 +00:00
# Alias for one assembled pipeline stage: the stage implementation, its parameters,
# and an optional dict of stage-specific keyword arguments (may be None).
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]

# Module-level logger, named after this module.
logger = getLogger(name=__name__)
class ChainProgress:
    """
    Progress callback wrapper that keeps the step count reported to its parent increasing.

    Whenever a call reports a step lower than the previous one (a counter reset), the previous
    step is folded into a running total. The parent callback therefore receives a cumulative
    step count rather than the raw, resetting one.
    """

    def __init__(self, parent: ProgressCallback, start: int = 0) -> None:
        self.parent = parent
        # last step reported to this wrapper (initially the starting offset)
        self.step = start
        # steps accumulated from counters that have already reset
        self.total = 0

    def __call__(self, step: int, timestep: int, latents: Any) -> None:
        """
        Record `step` and forward the cumulative step, `timestep` and `latents` to the parent.
        """
        if step < self.step:
            # the counter went backwards: accumulate on resets
            self.total += self.step

        self.step = step
        self.parent(self.get_total(), timestep, latents)

    def get_total(self) -> int:
        """Return the cumulative step count (current step plus accumulated total)."""
        return self.step + self.total

    @classmethod
    def from_progress(cls, parent: ProgressCallback) -> "ChainProgress":
        """
        Wrap an existing progress callback.

        The wrapper starts from `parent.step` when the parent has a `step` attribute, else 0.
        Uses `cls` so that subclasses construct instances of themselves.
        """
        start = getattr(parent, "step", 0)
        return cls(parent, start=start)
class ChainPipeline:
    """
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    """

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.

        The stages are copied into a new list, so later changes to the caller's list do not
        affect this pipeline.
        """
        self.stages = list(stages or [])

    def append(self, stage: Optional[PipelineStage]):
        """
        Append an additional stage to this pipeline.

        This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
        assemble the stage from loose arguments. A `None` stage is ignored.
        """
        if stage is not None:
            self.stages.append(stage)

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        source: Optional[Image.Image],
        callback: Optional[ProgressCallback],
        **kwargs
    ) -> Image.Image:
        """
        Run the pipeline on a single, optional source image and return the first result.

        `source` is wrapped into the source list that `__call__` expects (an empty list when it
        is `None`). Extra keyword arguments are passed through to every stage.

        Raises `ValueError` if the pipeline finishes without producing any images.

        TODO: handle List[Image] inputs and outputs
        """
        # __call__ takes a list of source images, not a single `source` keyword
        sources = [] if source is None else [source]
        results = self(job, server, params, sources, callback=callback, **kwargs)
        if not results:
            raise ValueError("pipeline did not produce any images")

        return results[0]

    def stage(self, callback: BaseStage, params: StageParams, **kwargs):
        """
        Assemble a stage from its implementation (`callback`), its `StageParams`, and any
        stage-specific keyword arguments, then append it to this pipeline.

        Returns this pipeline so calls can be chained.
        """
        self.stages.append((callback, params, kwargs))
        return self

    def __call__(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: List[Image.Image],
        callback: Optional[ProgressCallback] = None,
        **pipeline_kwargs
    ) -> List[Image.Image]:
        """
        DEPRECATED: use `run` instead

        Run each stage in order, feeding the images returned by one stage into the next.
        Returns the images produced by the last stage, or `sources` itself when there are no
        stages. A `None` source list is treated as empty.

        A stage is split into tiles when `needs_tile` reports that any of its source images
        needs tiling. Otherwise the stage's `run` is called once with the whole source list.

        `pipeline_kwargs` are passed to every stage; per-stage kwargs take precedence.
        """
        if callback is not None:
            # keep the reported step count increasing across stages
            callback = ChainProgress.from_progress(callback)

        start = monotonic()

        if sources is None:
            sources = []

        if len(sources) > 0:
            logger.info(
                "running pipeline on %s source images",
                len(sources),
            )
        else:
            logger.info("running pipeline without source images")

        stage_sources = sources
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            name = stage_params.name or stage_pipe.__class__.__name__
            # per-stage kwargs override pipeline-wide kwargs with the same name
            kwargs = {**pipeline_kwargs, **(stage_kwargs or {})}

            if len(stage_sources) > 0:
                logger.debug(
                    "running stage %s with %s source images, parameters: %s",
                    name,
                    len(stage_sources),
                    kwargs.keys(),
                )
            else:
                logger.debug(
                    "running stage %s without source images, parameters: %s",
                    name,
                    kwargs.keys(),
                )

            # effective tile size: the stage's max_tile caps the requested size when positive.
            # Computed for every stage so both branches below can log it.
            tile = stage_params.tile_size
            if stage_pipe.max_tile > 0:
                tile = min(stage_pipe.max_tile, stage_params.tile_size)

            # the stage must be split and tiled if any image is larger than the selected/max tile size
            must_tile = any(
                needs_tile(
                    stage_pipe.max_tile,
                    stage_params.tile_size,
                    size=kwargs.get("size", None),
                    source=source,
                )
                for source in stage_sources
            )

            if must_tile:
                logger.info(
                    "image larger than tile size of %s, tiling stage",
                    tile,
                )

                def stage_tile(source_tile: Image.Image, _dims) -> Image.Image:
                    # closes over the current stage's variables; it is only used
                    # within this loop iteration, so late binding is not an issue
                    output_tile = stage_pipe.run(
                        job,
                        server,
                        stage_params,
                        params,
                        source_tile,
                        callback=callback,
                        **kwargs,
                    )

                    if is_debug():
                        save_image(server, "last-tile.png", output_tile)

                    return output_tile

                stage_sources = [
                    process_tile_order(
                        stage_params.tile_order,
                        source,
                        tile,
                        stage_params.outscale,
                        [stage_tile],
                        **kwargs,
                    )
                    for source in stage_sources
                ]
            else:
                logger.debug("image within tile size of %s, running stage", tile)
                stage_sources = stage_pipe.run(
                    job,
                    server,
                    stage_params,
                    params,
                    stage_sources,
                    callback=callback,
                    **kwargs,
                )

            logger.debug(
                "finished stage %s with %s results",
                name,
                len(stage_sources),
            )

            # a stage may return no images; only save a debug copy when there is one
            if is_debug() and len(stage_sources) > 0:
                save_image(server, "last-stage.png", stage_sources[0])

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info(
            "finished pipeline in %s with %s results",
            duration,
            len(stage_sources),
        )

        return stage_sources