track results after each stage
This commit is contained in:
parent
4f1bc84fd9
commit
28fc2082c7
|
@ -27,6 +27,7 @@ class ChainProgress:
|
||||||
total: int
|
total: int
|
||||||
stage: int
|
stage: int
|
||||||
tile: int
|
tile: int
|
||||||
|
results: int
|
||||||
# TODO: total stages and tiles
|
# TODO: total stages and tiles
|
||||||
|
|
||||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
|
@ -35,6 +36,7 @@ class ChainProgress:
|
||||||
self.total = 0
|
self.total = 0
|
||||||
self.stage = 0
|
self.stage = 0
|
||||||
self.tile = 0
|
self.tile = 0
|
||||||
|
self.results = 0
|
||||||
|
|
||||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||||
if step < self.step:
|
if step < self.step:
|
||||||
|
@ -174,6 +176,7 @@ class ChainPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
callback.tile = 0
|
callback.tile = 0
|
||||||
|
|
||||||
def stage_tile(
|
def stage_tile(
|
||||||
source_tile: List[Image.Image],
|
source_tile: List[Image.Image],
|
||||||
tile_mask: Image.Image,
|
tile_mask: Image.Image,
|
||||||
|
@ -224,6 +227,7 @@ class ChainPipeline:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
callback.results = len(stage_results)
|
||||||
stage_sources = StageResult(images=stage_results)
|
stage_sources = StageResult(images=stage_results)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -279,6 +283,8 @@ class ChainPipeline:
|
||||||
duration,
|
duration,
|
||||||
len(stage_sources),
|
len(stage_sources),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
callback.results = len(stage_sources)
|
||||||
return stage_sources
|
return stage_sources
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -104,7 +104,12 @@ class WorkerContext:
|
||||||
|
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
self.callback.step = step
|
self.callback.step = step
|
||||||
self.set_progress(step, stages=self.callback.stage, tiles=self.callback.tile)
|
self.set_progress(
|
||||||
|
step,
|
||||||
|
stages=self.callback.stage,
|
||||||
|
tiles=self.callback.tile,
|
||||||
|
results=self.callback.results,
|
||||||
|
)
|
||||||
|
|
||||||
self.callback = ChainProgress.from_progress(on_progress)
|
self.callback = ChainProgress.from_progress(on_progress)
|
||||||
return self.callback
|
return self.callback
|
||||||
|
@ -117,7 +122,9 @@ class WorkerContext:
|
||||||
with self.idle.get_lock():
|
with self.idle.get_lock():
|
||||||
self.idle.value = idle
|
self.idle.value = idle
|
||||||
|
|
||||||
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
def set_progress(
|
||||||
|
self, steps: int, stages: int = 0, tiles: int = 0, results: int = 0
|
||||||
|
) -> None:
|
||||||
if self.job is None:
|
if self.job is None:
|
||||||
raise RuntimeError("no job on which to set progress")
|
raise RuntimeError("no job on which to set progress")
|
||||||
|
|
||||||
|
@ -133,6 +140,7 @@ class WorkerContext:
|
||||||
steps=steps,
|
steps=steps,
|
||||||
stages=stages,
|
stages=stages,
|
||||||
tiles=tiles,
|
tiles=tiles,
|
||||||
|
results=results,
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
@ -148,8 +156,11 @@ class WorkerContext:
|
||||||
self.job,
|
self.job,
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.SUCCESS, # TODO: FAILED
|
JobStatus.SUCCESS,
|
||||||
steps=self.get_progress(),
|
steps=self.last_progress.steps,
|
||||||
|
stages=self.last_progress.stages,
|
||||||
|
tiles=self.last_progress.tiles,
|
||||||
|
results=self.last_progress.results,
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
|
Loading…
Reference in New Issue