report stage and tile count in progress
This commit is contained in:
parent
9d05d9baac
commit
0c504e3f69
|
@ -22,6 +22,13 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]]
|
|||
|
||||
|
||||
class ChainProgress:
|
||||
parent: ProgressCallback
|
||||
step: int
|
||||
total: int
|
||||
stage: int
|
||||
tile: int
|
||||
# TODO: total stages and tiles
|
||||
|
||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||
self.parent = parent
|
||||
self.step = start
|
||||
|
@ -94,12 +101,8 @@ class ChainPipeline:
|
|||
|
||||
return steps
|
||||
|
||||
def outputs(self, params: ImageParams, sources: int) -> int:
|
||||
outputs = sources
|
||||
for callback, _params, kwargs in self.stages:
|
||||
outputs = callback.outputs(kwargs.get("params", params), outputs)
|
||||
|
||||
return outputs
|
||||
def stages(self) -> int:
|
||||
return len(self.stages)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -110,12 +113,11 @@ class ChainPipeline:
|
|||
callback: Optional[ProgressCallback] = None,
|
||||
**pipeline_kwargs,
|
||||
) -> StageResult:
|
||||
"""
|
||||
DEPRECATED: use `.run()` instead
|
||||
"""
|
||||
if callback is None:
|
||||
callback = worker.get_progress_callback()
|
||||
else:
|
||||
|
||||
# wrap the progress counter in a one that can be reset if needed
|
||||
if not isinstance(callback, ChainProgress):
|
||||
callback = ChainProgress.from_progress(callback)
|
||||
|
||||
start = monotonic()
|
||||
|
@ -129,7 +131,7 @@ class ChainPipeline:
|
|||
logger.info("running pipeline without source images")
|
||||
|
||||
stage_sources = sources
|
||||
for stage_pipe, stage_params, stage_kwargs in self.stages:
|
||||
for stage_i, (stage_pipe, stage_params, stage_kwargs) in enumerate(self.stages):
|
||||
name = stage_params.name or stage_pipe.__class__.__name__
|
||||
kwargs = stage_kwargs or {}
|
||||
kwargs = {**pipeline_kwargs, **kwargs}
|
||||
|
@ -139,6 +141,7 @@ class ChainPipeline:
|
|||
len(stage_sources),
|
||||
kwargs.keys(),
|
||||
)
|
||||
callback.stage = stage_i
|
||||
|
||||
per_stage_params = params
|
||||
if "params" in kwargs:
|
||||
|
@ -168,6 +171,7 @@ class ChainPipeline:
|
|||
tile,
|
||||
)
|
||||
|
||||
callback.tile = 0
|
||||
def stage_tile(
|
||||
source_tile: List[Image.Image],
|
||||
tile_mask: Image.Image,
|
||||
|
@ -191,6 +195,8 @@ class ChainPipeline:
|
|||
for j, image in enumerate(tile_result.as_image()):
|
||||
save_image(server, f"last-tile-{j}.png", image)
|
||||
|
||||
callback.tile = callback.tile + 1
|
||||
|
||||
return tile_result
|
||||
except CancelledException as err:
|
||||
worker.retries = 0
|
||||
|
|
|
@ -28,6 +28,7 @@ class WorkerContext:
|
|||
timeout: float
|
||||
retries: int
|
||||
initial_retries: int
|
||||
callback: Optional[Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -97,11 +98,15 @@ class WorkerContext:
|
|||
def get_progress_callback(self) -> ProgressCallback:
|
||||
from ..chain.pipeline import ChainProgress
|
||||
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
on_progress.step = step
|
||||
self.set_progress(step)
|
||||
if self.callback is not None:
|
||||
return self.callback
|
||||
|
||||
return ChainProgress.from_progress(on_progress)
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
self.callback.step = step
|
||||
self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile)
|
||||
|
||||
self.callback = ChainProgress.from_progress(on_progress)
|
||||
return self.callback
|
||||
|
||||
def set_cancel(self, cancel: bool = True) -> None:
|
||||
with self.cancel.get_lock():
|
||||
|
@ -111,20 +116,22 @@ class WorkerContext:
|
|||
with self.idle.get_lock():
|
||||
self.idle.value = idle
|
||||
|
||||
def set_progress(self, progress: int) -> None:
|
||||
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
||||
if self.job is None:
|
||||
raise RuntimeError("no job on which to set progress")
|
||||
|
||||
if self.is_cancelled():
|
||||
raise CancelledException("job has been cancelled")
|
||||
|
||||
logger.debug("setting progress for job %s to %s", self.job, progress)
|
||||
logger.debug("setting progress for job %s to %s", self.job, steps)
|
||||
self.last_progress = ProgressCommand(
|
||||
self.job,
|
||||
self.job_type,
|
||||
self.device.device,
|
||||
JobStatus.RUNNING,
|
||||
steps=progress,
|
||||
steps=steps,
|
||||
stages=stages,
|
||||
tiles=tiles,
|
||||
)
|
||||
self.progress.put(
|
||||
self.last_progress,
|
||||
|
|
Loading…
Reference in New Issue