track results after each stage
This commit is contained in:
parent
4f1bc84fd9
commit
28fc2082c7
|
@ -27,6 +27,7 @@ class ChainProgress:
|
|||
total: int
|
||||
stage: int
|
||||
tile: int
|
||||
results: int
|
||||
# TODO: total stages and tiles
|
||||
|
||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||
|
@ -35,6 +36,7 @@ class ChainProgress:
|
|||
self.total = 0
|
||||
self.stage = 0
|
||||
self.tile = 0
|
||||
self.results = 0
|
||||
|
||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||
if step < self.step:
|
||||
|
@ -174,6 +176,7 @@ class ChainPipeline:
|
|||
)
|
||||
|
||||
callback.tile = 0
|
||||
|
||||
def stage_tile(
|
||||
source_tile: List[Image.Image],
|
||||
tile_mask: Image.Image,
|
||||
|
@ -224,6 +227,7 @@ class ChainPipeline:
|
|||
**kwargs,
|
||||
)
|
||||
|
||||
callback.results = len(stage_results)
|
||||
stage_sources = StageResult(images=stage_results)
|
||||
else:
|
||||
logger.debug(
|
||||
|
@ -279,6 +283,8 @@ class ChainPipeline:
|
|||
duration,
|
||||
len(stage_sources),
|
||||
)
|
||||
|
||||
callback.results = len(stage_sources)
|
||||
return stage_sources
|
||||
|
||||
|
||||
|
|
|
@ -104,7 +104,12 @@ class WorkerContext:
|
|||
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
self.callback.step = step
|
||||
self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile)
|
||||
self.set_progress(
|
||||
step,
|
||||
stages=self.callback.stage,
|
||||
tiles=self.callback.tile,
|
||||
results=self.callback.results,
|
||||
)
|
||||
|
||||
self.callback = ChainProgress.from_progress(on_progress)
|
||||
return self.callback
|
||||
|
@ -117,7 +122,9 @@ class WorkerContext:
|
|||
with self.idle.get_lock():
|
||||
self.idle.value = idle
|
||||
|
||||
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
||||
def set_progress(
|
||||
self, steps: int, stages: int = 0, tiles: int = 0, results: int = 0
|
||||
) -> None:
|
||||
if self.job is None:
|
||||
raise RuntimeError("no job on which to set progress")
|
||||
|
||||
|
@ -133,6 +140,7 @@ class WorkerContext:
|
|||
steps=steps,
|
||||
stages=stages,
|
||||
tiles=tiles,
|
||||
results=results,
|
||||
)
|
||||
self.progress.put(
|
||||
self.last_progress,
|
||||
|
@ -148,8 +156,11 @@ class WorkerContext:
|
|||
self.job,
|
||||
self.job_type,
|
||||
self.device.device,
|
||||
JobStatus.SUCCESS, # TODO: FAILED
|
||||
steps=self.get_progress(),
|
||||
JobStatus.SUCCESS,
|
||||
steps=self.last_progress.steps,
|
||||
stages=self.last_progress.stages,
|
||||
tiles=self.last_progress.tiles,
|
||||
results=self.last_progress.results,
|
||||
)
|
||||
self.progress.put(
|
||||
self.last_progress,
|
||||
|
|
Loading…
Reference in New Issue