From 10acad232c9e9d89f77bdf0e21118e6516784303 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Thu, 4 Jan 2024 19:48:43 -0600 Subject: [PATCH] estimate steps better, pass progress onto reply --- api/onnx_web/chain/pipeline.py | 8 ++++++-- api/onnx_web/server/api.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index d58cca86..c68f4632 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -140,8 +140,12 @@ class ChainPipeline: callback = ChainProgress.from_progress(callback) # set estimated totals - # TODO: should use self.steps, but size is not available here - callback.set_total(params.steps, stages=len(self.stages), tiles=0) + steps = params.steps + if "size" in pipeline_kwargs and isinstance(pipeline_kwargs["size"], Size): + size = pipeline_kwargs["size"] + steps = self.steps(params, size) + + callback.set_total(steps, stages=len(self.stages), tiles=0) start = monotonic() diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 02111afa..1166c0dd 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -685,9 +685,9 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): server, job_name, status, - stages=Progress(progress.stages, 0), - steps=Progress(progress.steps, 0), - tiles=Progress(progress.tiles, 0), + stages=progress.stages, + steps=progress.steps, + tiles=progress.tiles, outputs=outputs, metadata=metadata, )