From 4dc251cf4ae1b45dac4e1655a2559d460f1988bf Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 4 Jan 2024 19:39:44 -0600 Subject: [PATCH] fix callback access --- api/onnx_web/chain/pipeline.py | 5 ++--- api/onnx_web/worker/context.py | 16 +++++++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index 67b046ab..d58cca86 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -140,9 +140,8 @@ class ChainPipeline: callback = ChainProgress.from_progress(callback) # set estimated totals - callback.set_total( - self.steps(params, sources.size), stages=len(self.stages), tiles=0 - ) + # TODO: should use self.steps, but size is not available here + callback.set_total(params.steps, stages=len(self.stages), tiles=0) start = monotonic() diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 7e83a604..a5a6325b 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -121,8 +121,8 @@ class WorkerContext: # self.callback.step = step self.set_progress( step, - stages=self.callback.stage, - tiles=self.callback.tile, + stages=self.callback.stages.current, + tiles=self.callback.tiles.current, ) self.callback = ChainProgress.from_progress(on_progress) @@ -144,8 +144,14 @@ class WorkerContext: raise CancelledException("job has been cancelled") result = None + total_steps = 0 + total_stages = 0 + total_tiles = 0 if self.callback is not None: result = self.callback.result + total_steps = self.callback.steps.total + total_stages = self.callback.stages.total + total_tiles = self.callback.tiles.total logger.debug("setting progress for job %s to %s", self.job, steps) self.last_progress = ProgressCommand( @@ -153,9 +159,9 @@ class WorkerContext: self.job_type, self.device.device, JobStatus.RUNNING, - steps=Progress(steps, self.callback.steps.total), - stages=Progress(stages, self.callback.stages.total), - tiles=Progress(tiles, self.callback.tiles.total), + steps=Progress(steps, total_steps), + stages=Progress(stages, total_stages), + tiles=Progress(tiles, total_tiles), result=result, ) self.progress.put(