From 28fc2082c736f5a047b5257c7d601ca2557d7a20 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 3 Jan 2024 21:31:41 -0600 Subject: [PATCH] track results after each stage --- api/onnx_web/chain/pipeline.py | 6 ++++++ api/onnx_web/worker/context.py | 19 +++++++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index b1f01ced..a36248a0 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -27,6 +27,7 @@ class ChainProgress: total: int stage: int tile: int + results: int # TODO: total stages and tiles def __init__(self, parent: ProgressCallback, start=0) -> None: @@ -35,6 +36,7 @@ class ChainProgress: self.total = 0 self.stage = 0 self.tile = 0 + self.results = 0 def __call__(self, step: int, timestep: int, latents: Any) -> None: if step < self.step: @@ -174,6 +176,7 @@ class ChainPipeline: ) callback.tile = 0 + def stage_tile( source_tile: List[Image.Image], tile_mask: Image.Image, @@ -224,6 +227,7 @@ class ChainPipeline: **kwargs, ) + callback.results = len(stage_results) stage_sources = StageResult(images=stage_results) else: logger.debug( @@ -279,6 +283,8 @@ class ChainPipeline: duration, len(stage_sources), ) + + callback.results = len(stage_sources) return stage_sources diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index a5bd10d7..2e832169 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -104,7 +104,12 @@ class WorkerContext: def on_progress(step: int, timestep: int, latents: Any): self.callback.step = step - self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile) + self.set_progress( + step, + stages=self.callback.stage, + tiles=self.callback.tile, + results=self.callback.results, + ) self.callback = ChainProgress.from_progress(on_progress) return self.callback @@ -117,7 +122,9 @@ class WorkerContext: with self.idle.get_lock(): self.idle.value = idle - def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None: + def set_progress( + self, steps: int, stages: int = 0, tiles: int = 0, results: int = 0 + ) -> None: if self.job is None: raise RuntimeError("no job on which to set progress") @@ -133,6 +140,7 @@ class WorkerContext: steps=steps, stages=stages, tiles=tiles, + results=results, ) self.progress.put( self.last_progress, @@ -148,8 +156,11 @@ class WorkerContext: self.job, self.job_type, self.device.device, - JobStatus.SUCCESS, # TODO: FAILED - steps=self.get_progress(), + JobStatus.SUCCESS, + steps=self.last_progress.steps, + stages=self.last_progress.stages, + tiles=self.last_progress.tiles, + results=self.last_progress.results, ) self.progress.put( self.last_progress,