1
0
Fork 0
onnx-web/api/onnx_web/chain/base.py

264 lines
9.1 KiB
Python
Raw Normal View History

2023-01-28 23:09:19 +00:00
from datetime import timedelta
from logging import getLogger
2023-01-28 20:56:06 +00:00
from time import monotonic
2023-07-04 17:09:46 +00:00
from typing import Any, List, Optional, Tuple
2023-02-05 13:53:26 +00:00
from PIL import Image
2023-07-15 22:05:27 +00:00
from ..errors import RetryException
2023-02-05 13:53:26 +00:00
from ..output import save_image
from ..params import ImageParams, StageParams
2023-02-26 05:49:39 +00:00
from ..server import ServerContext
2023-07-15 16:00:06 +00:00
from ..utils import is_debug, run_gc
2023-02-26 20:15:30 +00:00
from ..worker import ProgressCallback, WorkerContext
from .stage import BaseStage
from .tile import needs_tile, process_tile_order
2023-01-28 23:09:19 +00:00
# Module-level logger, named after this module.
logger = getLogger(__name__)
# One pipeline entry: the stage to run, the stage parameters for it, and an
# optional dict of extra keyword arguments that are passed to that stage.
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
    """
    Progress callback wrapper that reports a running total to a parent callback.

    Whenever the reported step goes backwards, the steps counted so far are
    folded into a running total, so the parent always receives the sum of all
    steps seen rather than a counter that resets.
    """

    def __init__(self, parent: ProgressCallback, start=0) -> None:
        """
        Wrap `parent`, treating `start` as the step that has already been reached.
        """
        self.parent = parent
        self.step = start
        self.total = 0

    def __call__(self, step: int, timestep: int, latents: Any) -> None:
        """
        Record `step` and forward the accumulated total to the parent callback.

        `timestep` and `latents` are passed through to the parent unchanged.
        """
        if step < self.step:
            # accumulate on resets
            self.total += self.step

        self.step = step
        self.parent(self.get_total(), timestep, latents)

    def get_total(self) -> int:
        """
        Return the current step plus all steps accumulated before earlier resets.
        """
        return self.step + self.total

    @classmethod
    def from_progress(cls, parent: ProgressCallback) -> "ChainProgress":
        """
        Build a progress wrapper around `parent`.

        If the parent exposes a `step` attribute, counting continues from it;
        otherwise it starts from 0. The instance is created through `cls`, so
        subclasses get an instance of their own type.
        """
        return cls(parent, start=getattr(parent, "step", 0))
class ChainPipeline:
    """
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    """

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.

        The stages are copied into a new list, so later changes to the pipeline
        do not modify the list that was passed in.
        """
        self.stages = list(stages or [])

    def append(self, stage: Optional[PipelineStage]):
        """
        Append an additional stage to this pipeline.

        This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
        assemble the stage from loose arguments.

        Passing `None` is allowed and does nothing.
        """
        if stage is not None:
            self.stages.append(stage)

    def run(
        self,
        worker: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: List[Image.Image],
        callback: Optional[ProgressCallback],
        **kwargs
    ) -> List[Image.Image]:
        """
        Run every stage of the pipeline and return the images from the last one.

        This forwards all of its arguments to `__call__`.
        """
        return self(
            worker, server, params, sources=sources, callback=callback, **kwargs
        )

    def stage(self, callback: BaseStage, params: StageParams, **kwargs):
        """
        Assemble a stage from loose arguments and append it to this pipeline.

        The keyword arguments are stored with the stage and passed to it when
        the pipeline runs. Returns the pipeline itself so calls can be chained.
        """
        self.stages.append((callback, params, kwargs))
        return self

    @staticmethod
    def _run_with_retries(
        worker: WorkerContext,
        server: ServerContext,
        attempt: Any,
        error_format: str,
        exhausted_message: str,
    ) -> Any:
        """
        Call `attempt` until it succeeds or the worker's retry budget is used up.

        `attempt` is a callable taking no arguments; its return value is
        returned from the first successful call. At most `worker.retries`
        attempts are made (read once, when this method starts). Each failed
        attempt is logged using `error_format` (given the attempt index and
        the number of attempts), clears the server cache, runs garbage
        collection for the worker's device, and removes exactly one retry from
        `worker.retries`.

        Raises `RetryException` with `exhausted_message` when no attempt
        succeeded, including when there were no retries left to begin with.
        """
        attempts = worker.retries
        for attempt_index in range(attempts):
            try:
                return attempt()
            except Exception:
                logger.exception(error_format, attempt_index, attempts)
                # release cached state and memory before the next attempt
                server.cache.clear()
                run_gc([worker.get_device()])
                # one failure costs one retry; the budget lives on the worker,
                # so it carries over to later tiles and stages
                worker.retries = worker.retries - 1

        raise RetryException(exhausted_message)

    def _run_tiled(
        self,
        worker: WorkerContext,
        server: ServerContext,
        stage_pipe: BaseStage,
        stage_params: StageParams,
        per_stage_params: ImageParams,
        stage_sources: List[Image.Image],
        tile: int,
        callback: Optional[ProgressCallback],
        kwargs: dict,
    ) -> List[Image.Image]:
        """
        Run one stage over each source image, split into tiles of size `tile`.

        Each source is handed to `process_tile_order`, which calls back into
        the stage once per tile. When the stage returns more than one image
        for a tile, the tiles are also collected per output index and pasted
        at their tile offsets into additional output images, which are
        appended after the main output for that source.
        """
        stage_outputs = []
        for source in stage_sources:
            logger.info(
                "image larger than tile size of %s, tiling stage",
                tile,
            )

            # one list of (tile image, dims) pairs per output index of the stage
            extra_tiles = []

            def stage_tile(
                source_tile: Image.Image,
                tile_mask: Image.Image,
                dims: Tuple[int, int, int],
            ) -> Image.Image:
                def run_tile() -> Image.Image:
                    output_tile = stage_pipe.run(
                        worker,
                        server,
                        stage_params,
                        per_stage_params,
                        [source_tile],
                        tile_mask=tile_mask,
                        callback=callback,
                        dims=dims,
                        **kwargs,
                    )

                    if len(output_tile) > 1:
                        while len(extra_tiles) < len(output_tile):
                            extra_tiles.append([])

                        for extra_tile, layer in zip(output_tile, extra_tiles):
                            layer.append((extra_tile, dims))

                    if is_debug():
                        save_image(server, "last-tile.png", output_tile[0])

                    return output_tile[0]

                # the bookkeeping above stays inside the retried call, so a
                # failure there is retried in the same way as a failed stage run
                return self._run_with_retries(
                    worker,
                    server,
                    run_tile,
                    "error while running stage pipeline for tile, retry %s of %s",
                    "exhausted retries on tile",
                )

            output = process_tile_order(
                stage_params.tile_order,
                source,
                tile,
                stage_params.outscale,
                [stage_tile],
                **kwargs,
            )
            stage_outputs.append(output)

            if len(extra_tiles) > 1:
                for layer in extra_tiles:
                    layer_output = Image.new("RGB", output.size)
                    for layer_tile, dims in layer:
                        layer_output.paste(layer_tile, (dims[0], dims[1]))

                    stage_outputs.append(layer_output)

        return stage_outputs

    def __call__(
        self,
        worker: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: List[Image.Image],
        callback: Optional[ProgressCallback] = None,
        **pipeline_kwargs
    ) -> List[Image.Image]:
        """
        DEPRECATED: use `run` instead

        Runs each stage in order, feeding the images returned by one stage
        into the next, and returns the images from the final stage.

        Keyword arguments given here are passed to every stage; keyword
        arguments stored with a stage take precedence over them. A `params`
        keyword argument replaces `params` for that stage only.

        Raises `RetryException` when a stage or tile keeps failing until the
        worker's retries are used up.
        """
        if callback is not None:
            callback = ChainProgress.from_progress(callback)

        start = monotonic()

        if len(sources) > 0:
            logger.info(
                "running pipeline on %s source images",
                len(sources),
            )
        else:
            logger.info("running pipeline without source images")

        stage_sources = sources
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            name = stage_params.name or stage_pipe.__class__.__name__
            # stage-specific arguments override the pipeline-wide ones
            kwargs = {**pipeline_kwargs, **(stage_kwargs or {})}
            logger.debug(
                "running stage %s with %s source images, parameters: %s",
                name,
                len(stage_sources) - stage_sources.count(None),
                kwargs.keys(),
            )

            # a "params" argument overrides the image parameters for this stage
            # and must not be passed to the stage a second time as a keyword
            per_stage_params = kwargs.pop("params", params)

            # the stage must be split and tiled if any image is larger than the selected/max tile size
            must_tile = any(
                needs_tile(
                    stage_pipe.max_tile,
                    stage_params.tile_size,
                    size=kwargs.get("size", None),
                    source=source,
                )
                for source in stage_sources
            )

            tile = stage_params.tile_size
            if stage_pipe.max_tile > 0:
                tile = min(stage_pipe.max_tile, stage_params.tile_size)

            if must_tile:
                stage_sources = self._run_tiled(
                    worker,
                    server,
                    stage_pipe,
                    stage_params,
                    per_stage_params,
                    stage_sources,
                    tile,
                    callback,
                    kwargs,
                )
            else:
                logger.debug("image within tile size of %s, running stage", tile)

                def run_stage() -> List[Image.Image]:
                    return stage_pipe.run(
                        worker,
                        server,
                        stage_params,
                        per_stage_params,
                        stage_sources,
                        callback=callback,
                        dims=(0, 0, tile),
                        **kwargs,
                    )

                stage_outputs = self._run_with_retries(
                    worker,
                    server,
                    run_stage,
                    "error while running stage pipeline, retry %s of %s",
                    "exhausted retries on stage",
                )
                # the sources are only replaced after a successful run, so a
                # failed attempt never leaves them unset for the next one
                stage_sources = stage_outputs

            logger.debug(
                "finished stage %s with %s results",
                name,
                len(stage_sources),
            )

            # a stage may produce no images, so there may be nothing to save
            if is_debug() and len(stage_sources) > 0:
                save_image(server, "last-stage.png", stage_sources[0])

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info(
            "finished pipeline in %s with %s results",
            duration,
            len(stage_sources),
        )

        return stage_sources