From fac98ab239341df19fada55c86b9ad42d86aa6c7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 3 Jan 2024 23:13:21 -0600 Subject: [PATCH] keep result of each stage with metadata --- api/onnx_web/chain/pipeline.py | 12 ++++++------ api/onnx_web/server/api.py | 24 +++++++++++++++--------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index a36248a0..26b76fad 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -27,7 +27,7 @@ class ChainProgress: total: int stage: int tile: int - results: int + result: Optional[StageResult] # TODO: total stages and tiles def __init__(self, parent: ProgressCallback, start=0) -> None: @@ -36,7 +36,7 @@ class ChainProgress: self.total = 0 self.stage = 0 self.tile = 0 - self.results = 0 + self.result = None def __call__(self, step: int, timestep: int, latents: Any) -> None: if step < self.step: @@ -169,14 +169,13 @@ class ChainPipeline: if stage_pipe.max_tile > 0: tile = min(stage_pipe.max_tile, stage_params.tile_size) + callback.tile = 0 # reset this either way if must_tile: logger.info( "image contains sources or is larger than tile size of %s, tiling stage", tile, ) - callback.tile = 0 - def stage_tile( source_tile: List[Image.Image], tile_mask: Image.Image, @@ -227,7 +226,6 @@ class ChainPipeline: **kwargs, ) - callback.results = len(stage_results) stage_sources = StageResult(images=stage_results) else: logger.debug( @@ -272,6 +270,8 @@ class ChainPipeline: len(stage_sources), ) + callback.result = stage_sources # this has just been set to the result of the last stage + if is_debug(): for j, image in enumerate(stage_sources.as_image()): save_image(server, f"last-stage-{j}.png", image) @@ -284,7 +284,7 @@ class ChainPipeline: len(stage_sources), ) - callback.results = len(stage_sources) + callback.result = stage_sources return stage_sources diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 58c904aa..a0310b24 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -130,15 +130,17 @@ def image_reply( } if outputs is not None: + if metadata is None: + logger.error("metadata is required with outputs") + return error_reply("metadata is required with outputs") + + if len(metadata) != len(outputs): + logger.error("metadata and outputs must be the same length") + return error_reply("metadata and outputs must be the same length") + + data["metadata"] = metadata data["outputs"] = outputs - if metadata is not None: - if len(metadata) != len(outputs): - logger.error("metadata and outputs must be the same length") - return error_reply("metadata and outputs must be the same length") - - data["metadata"] = metadata - return jsonify([data]) @@ -669,8 +671,11 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): # TODO: accumulate results if progress is not None: outputs = None - if progress.results > 0: - outputs = make_output_names(server, job_name, progress.results) + metadata = None + if progress.result is not None and len(progress.result) > 0: + # TODO: progress results should be a list of filenames and image metadata + outputs = make_output_names(server, job_name, len(progress.result)) + metadata = progress.result.metadata return image_reply( job_name, @@ -680,6 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor): steps=Progress(progress.steps, 0), tiles=Progress(progress.tiles, 0), outputs=outputs, + metadata=metadata, ) return image_reply(job_name, status, "TODO")