1
0
Fork 0

report stage and tile count in progress

This commit is contained in:
Sean Sube 2024-01-03 21:16:44 -06:00
parent 9d05d9baac
commit 0c504e3f69
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 31 additions and 18 deletions

View File

@ -22,6 +22,13 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress: class ChainProgress:
parent: ProgressCallback
step: int
total: int
stage: int
tile: int
# TODO: total stages and tiles
def __init__(self, parent: ProgressCallback, start=0) -> None: def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent self.parent = parent
self.step = start self.step = start
@ -94,12 +101,8 @@ class ChainPipeline:
return steps return steps
def outputs(self, params: ImageParams, sources: int) -> int: def stages(self) -> int:
outputs = sources return len(self.stages)
for callback, _params, kwargs in self.stages:
outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs
def __call__( def __call__(
self, self,
@ -110,12 +113,11 @@ class ChainPipeline:
callback: Optional[ProgressCallback] = None, callback: Optional[ProgressCallback] = None,
**pipeline_kwargs, **pipeline_kwargs,
) -> StageResult: ) -> StageResult:
"""
DEPRECATED: use `.run()` instead
"""
if callback is None: if callback is None:
callback = worker.get_progress_callback() callback = worker.get_progress_callback()
else:
# wrap the progress counter in a one that can be reset if needed
if not isinstance(callback, ChainProgress):
callback = ChainProgress.from_progress(callback) callback = ChainProgress.from_progress(callback)
start = monotonic() start = monotonic()
@ -129,7 +131,7 @@ class ChainPipeline:
logger.info("running pipeline without source images") logger.info("running pipeline without source images")
stage_sources = sources stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages: for stage_i, (stage_pipe, stage_params, stage_kwargs) in enumerate(self.stages):
name = stage_params.name or stage_pipe.__class__.__name__ name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {} kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs} kwargs = {**pipeline_kwargs, **kwargs}
@ -139,6 +141,7 @@ class ChainPipeline:
len(stage_sources), len(stage_sources),
kwargs.keys(), kwargs.keys(),
) )
callback.stage = stage_i
per_stage_params = params per_stage_params = params
if "params" in kwargs: if "params" in kwargs:
@ -168,6 +171,7 @@ class ChainPipeline:
tile, tile,
) )
callback.tile = 0
def stage_tile( def stage_tile(
source_tile: List[Image.Image], source_tile: List[Image.Image],
tile_mask: Image.Image, tile_mask: Image.Image,
@ -191,6 +195,8 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()): for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image) save_image(server, f"last-tile-{j}.png", image)
callback.tile = callback.tile + 1
return tile_result return tile_result
except CancelledException as err: except CancelledException as err:
worker.retries = 0 worker.retries = 0

View File

@ -28,6 +28,7 @@ class WorkerContext:
timeout: float timeout: float
retries: int retries: int
initial_retries: int initial_retries: int
callback: Optional[Any]
def __init__( def __init__(
self, self,
@ -97,11 +98,15 @@ class WorkerContext:
def get_progress_callback(self) -> ProgressCallback: def get_progress_callback(self) -> ProgressCallback:
from ..chain.pipeline import ChainProgress from ..chain.pipeline import ChainProgress
def on_progress(step: int, timestep: int, latents: Any): if self.callback is not None:
on_progress.step = step return self.callback
self.set_progress(step)
return ChainProgress.from_progress(on_progress) def on_progress(step: int, timestep: int, latents: Any):
self.callback.step = step
self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile)
self.callback = ChainProgress.from_progress(on_progress)
return self.callback
def set_cancel(self, cancel: bool = True) -> None: def set_cancel(self, cancel: bool = True) -> None:
with self.cancel.get_lock(): with self.cancel.get_lock():
@ -111,20 +116,22 @@ class WorkerContext:
with self.idle.get_lock(): with self.idle.get_lock():
self.idle.value = idle self.idle.value = idle
def set_progress(self, progress: int) -> None: def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
if self.job is None: if self.job is None:
raise RuntimeError("no job on which to set progress") raise RuntimeError("no job on which to set progress")
if self.is_cancelled(): if self.is_cancelled():
raise CancelledException("job has been cancelled") raise CancelledException("job has been cancelled")
logger.debug("setting progress for job %s to %s", self.job, progress) logger.debug("setting progress for job %s to %s", self.job, steps)
self.last_progress = ProgressCommand( self.last_progress = ProgressCommand(
self.job, self.job,
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.RUNNING, JobStatus.RUNNING,
steps=progress, steps=steps,
stages=stages,
tiles=tiles,
) )
self.progress.put( self.progress.put(
self.last_progress, self.last_progress,