1
0
Fork 0

report stage and tile count in progress

This commit is contained in:
Sean Sube 2024-01-03 21:16:44 -06:00
parent 9d05d9baac
commit 0c504e3f69
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 31 additions and 18 deletions

View File

@ -22,6 +22,13 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
class ChainProgress:
parent: ProgressCallback
step: int
total: int
stage: int
tile: int
# TODO: total stages and tiles
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
@ -94,12 +101,8 @@ class ChainPipeline:
return steps
def outputs(self, params: ImageParams, sources: int) -> int:
outputs = sources
for callback, _params, kwargs in self.stages:
outputs = callback.outputs(kwargs.get("params", params), outputs)
return outputs
def stages(self) -> int:
return len(self.stages)
def __call__(
self,
@ -110,12 +113,11 @@ class ChainPipeline:
callback: Optional[ProgressCallback] = None,
**pipeline_kwargs,
) -> StageResult:
"""
DEPRECATED: use `.run()` instead
"""
if callback is None:
callback = worker.get_progress_callback()
else:
# wrap the progress counter in a one that can be reset if needed
if not isinstance(callback, ChainProgress):
callback = ChainProgress.from_progress(callback)
start = monotonic()
@ -129,7 +131,7 @@ class ChainPipeline:
logger.info("running pipeline without source images")
stage_sources = sources
for stage_pipe, stage_params, stage_kwargs in self.stages:
for stage_i, (stage_pipe, stage_params, stage_kwargs) in enumerate(self.stages):
name = stage_params.name or stage_pipe.__class__.__name__
kwargs = stage_kwargs or {}
kwargs = {**pipeline_kwargs, **kwargs}
@ -139,6 +141,7 @@ class ChainPipeline:
len(stage_sources),
kwargs.keys(),
)
callback.stage = stage_i
per_stage_params = params
if "params" in kwargs:
@ -168,6 +171,7 @@ class ChainPipeline:
tile,
)
callback.tile = 0
def stage_tile(
source_tile: List[Image.Image],
tile_mask: Image.Image,
@ -191,6 +195,8 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image)
callback.tile = callback.tile + 1
return tile_result
except CancelledException as err:
worker.retries = 0

View File

@ -28,6 +28,7 @@ class WorkerContext:
timeout: float
retries: int
initial_retries: int
callback: Optional[Any]
def __init__(
self,
@ -97,11 +98,15 @@ class WorkerContext:
def get_progress_callback(self) -> ProgressCallback:
from ..chain.pipeline import ChainProgress
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step
self.set_progress(step)
if self.callback is not None:
return self.callback
return ChainProgress.from_progress(on_progress)
def on_progress(step: int, timestep: int, latents: Any):
self.callback.step = step
self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile)
self.callback = ChainProgress.from_progress(on_progress)
return self.callback
def set_cancel(self, cancel: bool = True) -> None:
with self.cancel.get_lock():
@ -111,20 +116,22 @@ class WorkerContext:
with self.idle.get_lock():
self.idle.value = idle
def set_progress(self, progress: int) -> None:
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
if self.job is None:
raise RuntimeError("no job on which to set progress")
if self.is_cancelled():
raise CancelledException("job has been cancelled")
logger.debug("setting progress for job %s to %s", self.job, progress)
logger.debug("setting progress for job %s to %s", self.job, steps)
self.last_progress = ProgressCommand(
self.job,
self.job_type,
self.device.device,
JobStatus.RUNNING,
steps=progress,
steps=steps,
stages=stages,
tiles=tiles,
)
self.progress.put(
self.last_progress,