use pipeline output count
This commit is contained in:
parent
046de9bf3a
commit
505e408dd6
|
@ -92,6 +92,13 @@ class ChainPipeline:
|
||||||
|
|
||||||
return steps
|
return steps
|
||||||
|
|
||||||
|
def outputs(self, params: ImageParams, sources: int):
|
||||||
|
outputs = sources
|
||||||
|
for callback, _params, _kwargs in self.stages:
|
||||||
|
outputs += callback.outputs(params, outputs)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
worker: WorkerContext,
|
worker: WorkerContext,
|
||||||
|
|
|
@ -455,7 +455,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
logger.info("running chain pipeline with %s stages", len(pipeline.stages))
|
||||||
|
|
||||||
output = make_output_name(
|
output = make_output_name(
|
||||||
server, "chain", base_params, base_size, count=len(pipeline.stages)
|
server, "chain", base_params, base_size, count=pipeline.outputs(base_params, 0))
|
||||||
)
|
)
|
||||||
job_name = output[0]
|
job_name = output[0]
|
||||||
|
|
||||||
|
@ -471,7 +471,7 @@ def chain(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
needs_device=device,
|
needs_device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
step_params = params.with_args(steps=pipeline.steps(base_params, base_size))
|
step_params = base_params.with_args(steps=pipeline.steps(base_params, base_size))
|
||||||
return jsonify(json_params(output, step_params, base_size))
|
return jsonify(json_params(output, step_params, base_size))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue