use pipeline output count
This commit is contained in:
parent
046de9bf3a
commit
505e408dd6
|
@ -92,6 +92,13 @@ class ChainPipeline:
|
|||
|
||||
return steps
|
||||
|
||||
def outputs(self, params: ImageParams, sources: int):
|
||||
outputs = sources
|
||||
for callback, _params, _kwargs in self.stages:
|
||||
outputs += callback.outputs(params, outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
worker: WorkerContext,
|
||||
|
|
|
@ -455,7 +455,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||
|
||||
output = make_output_name(
|
||||
server, "chain", base_params, base_size, count=len(pipeline.stages)
|
||||
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0))
|
||||
)
|
||||
job_name = output[0]
|
||||
|
||||
|
@ -471,7 +471,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
|||
needs_device=device,
|
||||
)
|
||||
|
||||
step_params = params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||
return jsonify(json_params(output, step_params, base_size))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue