diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index ff3fae81..c52078b1 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -22,6 +22,13 @@ PipelineStage = Tuple[BaseStage, StageParams, Optional[dict]] class ChainProgress: + parent: ProgressCallback + step: int + total: int + stage: int + tile: int + # TODO: total stages and tiles + def __init__(self, parent: ProgressCallback, start=0) -> None: self.parent = parent self.step = start @@ -94,12 +101,8 @@ class ChainPipeline: return steps - def outputs(self, params: ImageParams, sources: int) -> int: - outputs = sources - for callback, _params, kwargs in self.stages: - outputs = callback.outputs(kwargs.get("params", params), outputs) - - return outputs + def stages(self) -> int: + return len(self.stages) def __call__( self, @@ -110,12 +113,11 @@ class ChainPipeline: callback: Optional[ProgressCallback] = None, **pipeline_kwargs, ) -> StageResult: - """ - DEPRECATED: use `.run()` instead - """ if callback is None: callback = worker.get_progress_callback() - else: + + # wrap the progress counter in a one that can be reset if needed + if not isinstance(callback, ChainProgress): callback = ChainProgress.from_progress(callback) start = monotonic() @@ -129,7 +131,7 @@ class ChainPipeline: logger.info("running pipeline without source images") stage_sources = sources - for stage_pipe, stage_params, stage_kwargs in self.stages: + for stage_i, (stage_pipe, stage_params, stage_kwargs) in enumerate(self.stages): name = stage_params.name or stage_pipe.__class__.__name__ kwargs = stage_kwargs or {} kwargs = {**pipeline_kwargs, **kwargs} @@ -139,6 +141,7 @@ class ChainPipeline: len(stage_sources), kwargs.keys(), ) + callback.stage = stage_i per_stage_params = params if "params" in kwargs: @@ -168,6 +171,7 @@ class ChainPipeline: tile, ) + callback.tile = 0 def stage_tile( source_tile: List[Image.Image], tile_mask: Image.Image, @@ -191,6 +195,8 @@ class ChainPipeline: for j, image in enumerate(tile_result.as_image()): save_image(server, f"last-tile-{j}.png", image) + callback.tile = callback.tile + 1 + return tile_result except CancelledException as err: worker.retries = 0 diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index e3de145c..83e3cfd0 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -28,6 +28,7 @@ class WorkerContext: timeout: float retries: int initial_retries: int + callback: Optional[Any] def __init__( self, @@ -97,11 +98,15 @@ class WorkerContext: def get_progress_callback(self) -> ProgressCallback: from ..chain.pipeline import ChainProgress - def on_progress(step: int, timestep: int, latents: Any): - on_progress.step = step - self.set_progress(step) + if self.callback is not None: + return self.callback - return ChainProgress.from_progress(on_progress) + def on_progress(step: int, timestep: int, latents: Any): + self.callback.step = step + self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile) + + self.callback = ChainProgress.from_progress(on_progress) + return self.callback def set_cancel(self, cancel: bool = True) -> None: with self.cancel.get_lock(): @@ -111,20 +116,22 @@ class WorkerContext: with self.idle.get_lock(): self.idle.value = idle - def set_progress(self, progress: int) -> None: + def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None: if self.job is None: raise RuntimeError("no job on which to set progress") if self.is_cancelled(): raise CancelledException("job has been cancelled") - logger.debug("setting progress for job %s to %s", self.job, progress) + logger.debug("setting progress for job %s to %s", self.job, steps) self.last_progress = ProgressCommand( self.job, self.job_type, self.device.device, JobStatus.RUNNING, - steps=progress, + steps=steps, + stages=stages, + tiles=tiles, ) self.progress.put( self.last_progress,