From 505e408dd6a9bc7a401eababca693591c7cbdf25 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 13 Sep 2023 08:43:31 -0500 Subject: [PATCH] use pipeline output count --- api/onnx_web/chain/base.py | 7 +++++++ api/onnx_web/server/api.py | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/api/onnx_web/chain/base.py b/api/onnx_web/chain/base.py index 76d5d58b..c775776d 100644 --- a/api/onnx_web/chain/base.py +++ b/api/onnx_web/chain/base.py @@ -92,6 +92,13 @@ class ChainPipeline: return steps + def outputs(self, params: ImageParams, sources: int): + outputs = sources + for callback, _params, _kwargs in self.stages: + outputs += callback.outputs(params, outputs) + + return outputs + def __call__( self, worker: WorkerContext, diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 5af4de1c..15569fc4 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -455,7 +455,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): logger.info("running chain pipeline with %s stages", len(pipeline.stages)) output = make_output_name( - server, "chain", base_params, base_size, count=len(pipeline.stages) + server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0)) ) job_name = output[0] @@ -471,7 +471,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor): needs_device=device, ) - step_params = params.with_args(steps=pipeline.steps(base_params, base_size)) + step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size)) return jsonify(json_params(output, step_params, base_size))