1
0
Fork 0

track results after each stage

This commit is contained in:
Sean Sube 2024-01-03 21:31:41 -06:00
parent 4f1bc84fd9
commit 28fc2082c7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 21 additions and 4 deletions

View File

@ -27,6 +27,7 @@ class ChainProgress:
total: int total: int
stage: int stage: int
tile: int tile: int
results: int
# TODO: total stages and tiles # TODO: total stages and tiles
def __init__(self, parent: ProgressCallback, start=0) -> None: def __init__(self, parent: ProgressCallback, start=0) -> None:
@ -35,6 +36,7 @@ class ChainProgress:
self.total = 0 self.total = 0
self.stage = 0 self.stage = 0
self.tile = 0 self.tile = 0
self.results = 0
def __call__(self, step: int, timestep: int, latents: Any) -> None: def __call__(self, step: int, timestep: int, latents: Any) -> None:
if step < self.step: if step < self.step:
@ -174,6 +176,7 @@ class ChainPipeline:
) )
callback.tile = 0 callback.tile = 0
def stage_tile( def stage_tile(
source_tile: List[Image.Image], source_tile: List[Image.Image],
tile_mask: Image.Image, tile_mask: Image.Image,
@ -224,6 +227,7 @@ class ChainPipeline:
**kwargs, **kwargs,
) )
callback.results = len(stage_results)
stage_sources = StageResult(images=stage_results) stage_sources = StageResult(images=stage_results)
else: else:
logger.debug( logger.debug(
@ -279,6 +283,8 @@ class ChainPipeline:
duration, duration,
len(stage_sources), len(stage_sources),
) )
callback.results = len(stage_sources)
return stage_sources return stage_sources

View File

@ -104,7 +104,12 @@ class WorkerContext:
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
self.callback.step = step self.callback.step = step
self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile) self.set_progress(
step,
stages=self.callback.stage,
tiles=self.callback.tile,
results=self.callback.results,
)
self.callback = ChainProgress.from_progress(on_progress) self.callback = ChainProgress.from_progress(on_progress)
return self.callback return self.callback
@ -117,7 +122,9 @@ class WorkerContext:
with self.idle.get_lock(): with self.idle.get_lock():
self.idle.value = idle self.idle.value = idle
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None: def set_progress(
self, steps: int, stages: int = 0, tiles: int = 0, results: int = 0
) -> None:
if self.job is None: if self.job is None:
raise RuntimeError("no job on which to set progress") raise RuntimeError("no job on which to set progress")
@ -133,6 +140,7 @@ class WorkerContext:
steps=steps, steps=steps,
stages=stages, stages=stages,
tiles=tiles, tiles=tiles,
results=results,
) )
self.progress.put( self.progress.put(
self.last_progress, self.last_progress,
@ -148,8 +156,11 @@ class WorkerContext:
self.job, self.job,
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.SUCCESS, # TODO: FAILED JobStatus.SUCCESS,
steps=self.get_progress(), steps=self.last_progress.steps,
stages=self.last_progress.stages,
tiles=self.last_progress.tiles,
results=self.last_progress.results,
) )
self.progress.put( self.progress.put(
self.last_progress, self.last_progress,