diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index d58cca86..c68f4632 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -140,8 +140,12 @@ class ChainPipeline: callback = ChainProgress.from_progress(callback) # set estimated totals - # TODO: should use self.steps, but size is not available here - callback.set_total(params.steps, stages=len(self.stages), tiles=0) + steps = params.steps + if "size" in pipeline_kwargs and isinstance(pipeline_kwargs["size"], Size): + size = pipeline_kwargs["size"] + steps = self.steps(params, size) + + callback.set_total(steps, stages=len(self.stages), tiles=0) start = monotonic() diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 02111afa..1166c0dd 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -685,9 +685,9 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): server, job_name, status, - stages=Progress(progress.stages, 0), - steps=Progress(progress.steps, 0), - tiles=Progress(progress.tiles, 0), + stages=progress.stages, + steps=progress.steps, + tiles=progress.tiles, outputs=outputs, metadata=metadata, )