keep result of each stage with metadata
This commit is contained in:
parent
e0d0933092
commit
fac98ab239
|
@ -27,7 +27,7 @@ class ChainProgress:
|
||||||
total: int
|
total: int
|
||||||
stage: int
|
stage: int
|
||||||
tile: int
|
tile: int
|
||||||
results: int
|
result: Optional[StageResult]
|
||||||
# TODO: total stages and tiles
|
# TODO: total stages and tiles
|
||||||
|
|
||||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
|
@ -36,7 +36,7 @@ class ChainProgress:
|
||||||
self.total = 0
|
self.total = 0
|
||||||
self.stage = 0
|
self.stage = 0
|
||||||
self.tile = 0
|
self.tile = 0
|
||||||
self.results = 0
|
self.result = None
|
||||||
|
|
||||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||||
if step < self.step:
|
if step < self.step:
|
||||||
|
@ -169,14 +169,13 @@ class ChainPipeline:
|
||||||
if stage_pipe.max_tile > 0:
|
if stage_pipe.max_tile > 0:
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
|
callback.tile = 0 # reset this either way
|
||||||
if must_tile:
|
if must_tile:
|
||||||
logger.info(
|
logger.info(
|
||||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
tile,
|
tile,
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.tile = 0
|
|
||||||
|
|
||||||
def stage_tile(
|
def stage_tile(
|
||||||
source_tile: List[Image.Image],
|
source_tile: List[Image.Image],
|
||||||
tile_mask: Image.Image,
|
tile_mask: Image.Image,
|
||||||
|
@ -227,7 +226,6 @@ class ChainPipeline:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.results = len(stage_results)
|
|
||||||
stage_sources = StageResult(images=stage_results)
|
stage_sources = StageResult(images=stage_results)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -272,6 +270,8 @@ class ChainPipeline:
|
||||||
len(stage_sources),
|
len(stage_sources),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
callback.result = stage_sources # this has just been set to the result of the last stage
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
for j, image in enumerate(stage_sources.as_image()):
|
for j, image in enumerate(stage_sources.as_image()):
|
||||||
save_image(server, f"last-stage-{j}.png", image)
|
save_image(server, f"last-stage-{j}.png", image)
|
||||||
|
@ -284,7 +284,7 @@ class ChainPipeline:
|
||||||
len(stage_sources),
|
len(stage_sources),
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.results = len(stage_sources)
|
callback.result = stage_sources
|
||||||
return stage_sources
|
return stage_sources
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -130,14 +130,16 @@ def image_reply(
|
||||||
}
|
}
|
||||||
|
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
data["outputs"] = outputs
|
if metadata is None:
|
||||||
|
logger.error("metadata is required with outputs")
|
||||||
|
return error_reply("metadata is required with outputs")
|
||||||
|
|
||||||
if metadata is not None:
|
|
||||||
if len(metadata) != len(outputs):
|
if len(metadata) != len(outputs):
|
||||||
logger.error("metadata and outputs must be the same length")
|
logger.error("metadata and outputs must be the same length")
|
||||||
return error_reply("metadata and outputs must be the same length")
|
return error_reply("metadata and outputs must be the same length")
|
||||||
|
|
||||||
data["metadata"] = metadata
|
data["metadata"] = metadata
|
||||||
|
data["outputs"] = outputs
|
||||||
|
|
||||||
return jsonify([data])
|
return jsonify([data])
|
||||||
|
|
||||||
|
@ -669,8 +671,11 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
# TODO: accumulate results
|
# TODO: accumulate results
|
||||||
if progress is not None:
|
if progress is not None:
|
||||||
outputs = None
|
outputs = None
|
||||||
if progress.results > 0:
|
metadata = None
|
||||||
outputs = make_output_names(server, job_name, progress.results)
|
if progress.result is not None and len(progress.result) > 0:
|
||||||
|
# TODO: progress results should be a list of filenames and image metadata
|
||||||
|
outputs = make_output_names(server, job_name, len(progress.result))
|
||||||
|
metadata = progress.result.metadata
|
||||||
|
|
||||||
return image_reply(
|
return image_reply(
|
||||||
job_name,
|
job_name,
|
||||||
|
@ -680,6 +685,7 @@ def job_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
steps=Progress(progress.steps, 0),
|
steps=Progress(progress.steps, 0),
|
||||||
tiles=Progress(progress.tiles, 0),
|
tiles=Progress(progress.tiles, 0),
|
||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
return image_reply(job_name, status, "TODO")
|
return image_reply(job_name, status, "TODO")
|
||||||
|
|
Loading…
Reference in New Issue