2023-01-28 23:09:19 +00:00
|
|
|
from datetime import timedelta
|
|
|
|
from logging import getLogger
|
2023-01-28 20:56:06 +00:00
|
|
|
from time import monotonic
|
2023-07-04 17:09:46 +00:00
|
|
|
from typing import Any, List, Optional, Tuple
|
2023-01-27 23:08:36 +00:00
|
|
|
|
2023-02-05 13:53:26 +00:00
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
from ..output import save_image
|
|
|
|
from ..params import ImageParams, StageParams
|
2023-02-26 05:49:39 +00:00
|
|
|
from ..server import ServerContext
|
2023-02-19 02:28:21 +00:00
|
|
|
from ..utils import is_debug
|
2023-02-26 20:15:30 +00:00
|
|
|
from ..worker import ProgressCallback, WorkerContext
|
2023-07-01 12:10:53 +00:00
|
|
|
from .stage import BaseStage
|
2023-07-02 23:14:52 +00:00
|
|
|
from .tile import needs_tile, process_tile_order
|
2023-01-27 23:08:36 +00:00
|
|
|
|
2023-01-28 23:09:19 +00:00
|
|
|
# logger for this module, named after the module's import path
logger = getLogger(name=__name__)
|
|
|
|
|
2023-01-27 23:08:36 +00:00
|
|
|
|
2023-07-01 12:10:53 +00:00
|
|
|
# One entry in `ChainPipeline.stages`: (stage implementation, its stage params, optional extra
# keyword arguments). A `None` kwargs entry is treated as `{}` when the pipeline runs.
PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
2023-01-27 23:08:36 +00:00
|
|
|
|
|
|
|
|
2023-02-12 18:17:36 +00:00
|
|
|
class ChainProgress:
    """
    Progress callback wrapper that reports a cumulative step count across several runs.

    Each wrapped run may restart its own step counter. Whenever a reported step is lower than the
    previous one, the previous step is treated as the end of a finished run and added to the
    running total. The parent callback always receives `get_total()`.
    """

    def __init__(self, parent: "ProgressCallback", start: int = 0) -> None:
        # callback that receives the accumulated step count
        self.parent = parent
        # last step reported by the current run (initially the starting offset)
        self.step = start
        # steps banked from earlier runs whose counters have since reset
        self.total = 0

    def __call__(self, step: int, timestep: int, latents: Any) -> None:
        if step < self.step:
            # the step counter went backwards, so a new run started: bank the previous one
            self.total += self.step

        self.step = step
        self.parent(self.get_total(), timestep, latents)

    def get_total(self) -> int:
        """Return the number of steps accumulated across all runs so far."""
        return self.step + self.total

    @classmethod
    def from_progress(cls, parent: "ProgressCallback") -> "ChainProgress":
        """
        Wrap an existing callback.

        Starts from the parent's current `step` attribute when it has one (for example, another
        `ChainProgress`), otherwise from zero. Uses `cls` so subclasses construct their own type.
        """
        start = getattr(parent, "step", 0)
        return cls(parent, start=start)
|
|
|
|
|
2023-02-12 18:17:36 +00:00
|
|
|
|
2023-01-27 23:08:36 +00:00
|
|
|
class ChainPipeline:
    """
    Run many stages in series, passing the image results from each to the next, and processing
    tiles as needed.
    """

    # how many times each stage (or each tile of a tiled stage) is attempted before the pipeline
    # gives up and raises; may be overridden on an instance or in a subclass
    stage_retries: int = 3

    def __init__(
        self,
        stages: Optional[List[PipelineStage]] = None,
    ):
        """
        Create a new pipeline that will run the given stages.

        The stages are copied into a new list, so `append`/`stage` never modify the caller's list.
        """
        self.stages = list(stages or [])

    def append(self, stage: Optional[PipelineStage]):
        """
        Append an additional stage to this pipeline.

        This requires an already-assembled `PipelineStage`. Use `ChainPipeline.stage` if you want the pipeline to
        assemble the stage from loose arguments. A `None` stage is ignored.
        """
        if stage is not None:
            self.stages.append(stage)

    def run(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: List[Image.Image],
        callback: Optional[ProgressCallback] = None,
        **kwargs
    ) -> List[Image.Image]:
        """
        Run every stage in order and return the images produced by the last stage.

        Extra keyword arguments are passed to every stage; a stage's own kwargs take precedence.
        """
        return self(job, server, params, sources=sources, callback=callback, **kwargs)

    def stage(self, callback: BaseStage, params: StageParams, **kwargs):
        """
        Assemble a stage from a stage implementation, its params, and extra keyword arguments,
        append it, and return this pipeline so calls can be chained.
        """
        self.stages.append((callback, params, kwargs))
        return self

    def __call__(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        sources: List[Image.Image],
        callback: Optional[ProgressCallback] = None,
        **pipeline_kwargs
    ) -> List[Image.Image]:
        """
        DEPRECATED: use `run` instead

        Run every stage in order, feeding the images produced by each stage into the next, and
        return the images produced by the last stage. A stage is split into tiles when any of its
        source images needs tiling for that stage.

        `pipeline_kwargs` are passed to every stage; a stage's own kwargs override them.

        Raises `RuntimeError` when a stage (or one of its tiles) still fails after
        `stage_retries` attempts, instead of silently continuing with a missing result.
        """
        if callback is not None:
            callback = ChainProgress.from_progress(callback)

        start = monotonic()

        if sources:
            logger.info(
                "running pipeline on %s source images",
                len(sources),
            )
        else:
            # stages that create images from scratch still run once, with a `None` placeholder source
            sources = [None]
            logger.info("running pipeline without source images")

        stage_sources = sources
        for stage_pipe, stage_params, stage_kwargs in self.stages:
            name = stage_params.name or stage_pipe.__class__.__name__
            kwargs = {**pipeline_kwargs, **(stage_kwargs or {})}
            logger.debug(
                "running stage %s with %s source images, parameters: %s",
                name,
                len(stage_sources) - stage_sources.count(None),
                kwargs.keys(),
            )

            tile = self._tile_size(stage_pipe, stage_params)
            if self._must_tile(stage_pipe, stage_params, stage_sources, kwargs):
                stage_sources = self._run_tiled(
                    job,
                    server,
                    params,
                    stage_pipe,
                    stage_params,
                    stage_sources,
                    tile,
                    callback,
                    kwargs,
                )
            else:
                logger.debug("image within tile size of %s, running stage", tile)
                current_sources = stage_sources
                stage_sources = self._run_with_retries(
                    lambda: stage_pipe.run(
                        job,
                        server,
                        stage_params,
                        params,
                        current_sources,
                        callback=callback,
                        **kwargs,
                    ),
                    "stage pipeline",
                )

            logger.debug(
                "finished stage %s with %s results",
                name,
                len(stage_sources),
            )

            # skip the debug snapshot when there is no real image to save
            if is_debug() and stage_sources and stage_sources[0] is not None:
                save_image(server, "last-stage.png", stage_sources[0])

        end = monotonic()
        duration = timedelta(seconds=(end - start))
        logger.info(
            "finished pipeline in %s with %s results",
            duration,
            len(stage_sources),
        )
        return stage_sources

    @staticmethod
    def _tile_size(stage_pipe: BaseStage, stage_params: StageParams) -> int:
        """Return the tile size for a stage: the requested size, capped by the stage's positive `max_tile`."""
        tile = stage_params.tile_size
        if stage_pipe.max_tile > 0:
            tile = min(stage_pipe.max_tile, tile)

        return tile

    @staticmethod
    def _must_tile(
        stage_pipe: BaseStage,
        stage_params: StageParams,
        sources: List[Image.Image],
        kwargs: dict,
    ) -> bool:
        """Return True when any source image needs tiling for this stage, according to `needs_tile`."""
        size = kwargs.get("size", None)
        return any(
            needs_tile(
                stage_pipe.max_tile,
                stage_params.tile_size,
                size=size,
                source=source,
            )
            for source in sources
        )

    def _run_tiled(
        self,
        job: WorkerContext,
        server: ServerContext,
        params: ImageParams,
        stage_pipe: BaseStage,
        stage_params: StageParams,
        sources: List[Image.Image],
        tile: int,
        callback: Optional[ProgressCallback],
        kwargs: dict,
    ) -> List[Image.Image]:
        """Run one stage tile-by-tile on each source image, returning one output per source."""

        def stage_tile(
            source_tile: Image.Image,
            tile_mask: Image.Image,
            dims: Tuple[int, int, int],
        ) -> Image.Image:
            output_tile = self._run_with_retries(
                lambda: stage_pipe.run(
                    job,
                    server,
                    stage_params,
                    params,
                    [source_tile],
                    tile_mask=tile_mask,
                    callback=callback,
                    dims=dims,
                    **kwargs,
                )[0],
                "stage pipeline for tile",
            )

            # saved outside the retry so a failed debug save does not rerun the stage
            if is_debug():
                save_image(server, "last-tile.png", output_tile)

            return output_tile

        outputs = []
        for source in sources:
            logger.info(
                "image larger than tile size of %s, tiling stage",
                tile,
            )
            outputs.append(
                process_tile_order(
                    stage_params.tile_order,
                    source,
                    tile,
                    stage_params.outscale,
                    [stage_tile],
                    **kwargs,
                )
            )

        return outputs

    def _run_with_retries(self, attempt: Any, description: str) -> Any:
        """
        Call `attempt()` up to `stage_retries` times (at least once), returning the first success.

        Each failure is logged with its traceback. When every attempt fails, a `RuntimeError` is
        raised, chained from the last error.
        """
        attempts = max(1, self.stage_retries)
        last_error: Optional[Exception] = None
        for i in range(attempts):
            try:
                return attempt()
            except Exception as err:
                last_error = err
                logger.exception(
                    "error while running %s, attempt %s of %s",
                    description,
                    i + 1,
                    attempts,
                )

        raise RuntimeError(
            f"exhausted {attempts} attempts while running {description}"
        ) from last_error
|