1
0
Fork 0

keep result of each stage with metadata

This commit is contained in:
Sean Sube 2024-01-03 23:13:21 -06:00
parent e0d0933092
commit fac98ab239
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 21 additions and 15 deletions

View File

@ -27,7 +27,7 @@ class ChainProgress:
total: int total: int
stage: int stage: int
tile: int tile: int
results: int result: Optional[StageResult]
# TODO: total stages and tiles # TODO: total stages and tiles
def __init__(self, parent: ProgressCallback, start=0) -> None: def __init__(self, parent: ProgressCallback, start=0) -> None:
@ -36,7 +36,7 @@ class ChainProgress:
self.total = 0 self.total = 0
self.stage = 0 self.stage = 0
self.tile = 0 self.tile = 0
self.results = 0 self.result = None
def __call__(self, step: int, timestep: int, latents: Any) -> None: def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step: if step < self.step:
@ -169,14 +169,13 @@ class ChainPipeline:
if stage_pipe.max_tile > 0: if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size) tile = min(stage_pipe.max_tile, stage_params.tile_size)
callback.tile = 0 # reset this either way
if must_tile: if must_tile:
logger.info( logger.info(
"image contains sources or is larger than tile size of %s, tiling stage", "image contains sources or is larger than tile size of %s, tiling stage",
tile, tile,
) )
callback.tile = 0
def stage_tile( def stage_tile(
source_tile: List[Image.Image], source_tile: List[Image.Image],
tile_mask: Image.Image, tile_mask: Image.Image,
@ -227,7 +226,6 @@ class ChainPipeline:
**kwargs, **kwargs,
) )
callback.results = len(stage_results)
stage_sources = StageResult(images=stage_results) stage_sources = StageResult(images=stage_results)
else: else:
logger.debug( logger.debug(
@ -272,6 +270,8 @@ class ChainPipeline:
len(stage_sources), len(stage_sources),
) )
callback.result = stage_sources # this has just been set to the result of the last stage
if is_debug(): if is_debug():
for j, image in enumerate(stage_sources.as_image()): for j, image in enumerate(stage_sources.as_image()):
save_image(server, f"last-stage-{j}.png", image) save_image(server, f"last-stage-{j}.png", image)
@ -284,7 +284,7 @@ class ChainPipeline:
len(stage_sources), len(stage_sources),
) )
callback.results = len(stage_sources) callback.result = stage_sources
return stage_sources return stage_sources

View File

@ -130,14 +130,16 @@ def image_reply(
} }
if outputs is not None: if outputs is not None:
data["outputs"] = outputs if metadata is None:
logger.error("metadata is required with outputs")
return error_reply("metadata is required with outputs")
if metadata is not None:
if len(metadata) != len(outputs): if len(metadata) != len(outputs):
logger.error("metadata and outputs must be the same length") logger.error("metadata and outputs must be the same length")
return error_reply("metadata and outputs must be the same length") return error_reply("metadata and outputs must be the same length")
data["metadata"] = metadata data["metadata"] = metadata
data["outputs"] = outputs
return jsonify([data]) return jsonify([data])
@ -669,8 +671,11 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
# TODO: accumulate results # TODO: accumulate results
if progress is not None: if progress is not None:
outputs = None outputs = None
if progress.results > 0: metadata = None
outputs = make_output_names(server, job_name, progress.results) if progress.result is not None and len(progress.result) > 0:
# TODO: progress results should be a list of filenames and image metadata
outputs = make_output_names(server, job_name, len(progress.result))
metadata = progress.result.metadata
return image_reply( return image_reply(
job_name, job_name,
@ -680,6 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
steps=Progress(progress.steps, 0), steps=Progress(progress.steps, 0),
tiles=Progress(progress.tiles, 0), tiles=Progress(progress.tiles, 0),
outputs=outputs, outputs=outputs,
metadata=metadata,
) )
return image_reply(job_name, status, "TODO") return image_reply(job_name, status, "TODO")