keep result of each stage with metadata
This commit is contained in:
parent
e0d0933092
commit
fac98ab239
|
@ -27,7 +27,7 @@ class ChainProgress:
|
|||
total: int
|
||||
stage: int
|
||||
tile: int
|
||||
results: int
|
||||
result: Optional[StageResult]
|
||||
# TODO: total stages and tiles
|
||||
|
||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||
|
@ -36,7 +36,7 @@ class ChainProgress:
|
|||
self.total = 0
|
||||
self.stage = 0
|
||||
self.tile = 0
|
||||
self.results = 0
|
||||
self.result = None
|
||||
|
||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||
if step < self.step:
|
||||
|
@ -169,14 +169,13 @@ class ChainPipeline:
|
|||
if stage_pipe.max_tile > 0:
|
||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||
|
||||
callback.tile = 0 # reset this either way
|
||||
if must_tile:
|
||||
logger.info(
|
||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||
tile,
|
||||
)
|
||||
|
||||
callback.tile = 0
|
||||
|
||||
def stage_tile(
|
||||
source_tile: List[Image.Image],
|
||||
tile_mask: Image.Image,
|
||||
|
@ -227,7 +226,6 @@ class ChainPipeline:
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
callback.results = len(stage_results)
|
||||
stage_sources = StageResult(images=stage_results)
|
||||
else:
|
||||
logger.debug(
|
||||
|
@ -272,6 +270,8 @@ class ChainPipeline:
|
|||
len(stage_sources),
|
||||
)
|
||||
|
||||
callback.result = stage_sources # this has just been set to the result of the last stage
|
||||
|
||||
if is_debug():
|
||||
for j, image in enumerate(stage_sources.as_image()):
|
||||
save_image(server, f"last-stage-{j}.png", image)
|
||||
|
@ -284,7 +284,7 @@ class ChainPipeline:
|
|||
len(stage_sources),
|
||||
)
|
||||
|
||||
callback.results = len(stage_sources)
|
||||
callback.result = stage_sources
|
||||
return stage_sources
|
||||
|
||||
|
||||
|
|
|
@ -130,15 +130,17 @@ def image_reply(
|
|||
}
|
||||
|
||||
if outputs is not None:
|
||||
if metadata is None:
|
||||
logger.error("metadata is required with outputs")
|
||||
return error_reply("metadata is required with outputs")
|
||||
|
||||
if len(metadata) != len(outputs):
|
||||
logger.error("metadata and outputs must be the same length")
|
||||
return error_reply("metadata and outputs must be the same length")
|
||||
|
||||
data["metadata"] = metadata
|
||||
data["outputs"] = outputs
|
||||
|
||||
if metadata is not None:
|
||||
if len(metadata) != len(outputs):
|
||||
logger.error("metadata and outputs must be the same length")
|
||||
return error_reply("metadata and outputs must be the same length")
|
||||
|
||||
data["metadata"] = metadata
|
||||
|
||||
return jsonify([data])
|
||||
|
||||
|
||||
|
@ -669,8 +671,11 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
|||
# TODO: accumulate results
|
||||
if progress is not None:
|
||||
outputs = None
|
||||
if progress.results > 0:
|
||||
outputs = make_output_names(server, job_name, progress.results)
|
||||
metadata = None
|
||||
if progress.result is not None and len(progress.result) > 0:
|
||||
# TODO: progress results should be a list of filenames and image metadata
|
||||
outputs = make_output_names(server, job_name, len(progress.result))
|
||||
metadata = progress.result.metadata
|
||||
|
||||
return image_reply(
|
||||
job_name,
|
||||
|
@ -680,6 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
|||
steps=Progress(progress.steps, 0),
|
||||
tiles=Progress(progress.tiles, 0),
|
||||
outputs=outputs,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return image_reply(job_name, status, "TODO")
|
||||
|
|
Loading…
Reference in New Issue