1
0
Fork 0

move progress state to worker context

This commit is contained in:
Sean Sube 2024-01-05 20:12:41 -06:00
parent 4f230f4111
commit 9b5d00a66a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 75 additions and 54 deletions

View File

@ -26,19 +26,13 @@ class ChainProgress:
step: int # same as steps.current, left for legacy purposes step: int # same as steps.current, left for legacy purposes
prev: int # accumulator when step resets prev: int # accumulator when step resets
# new progress trackers # TODO: should probably be moved to worker context as well
steps: Progress
stages: Progress
tiles: Progress
result: Optional[StageResult] result: Optional[StageResult]
def __init__(self, parent: ProgressCallback, start=0) -> None: def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent self.parent = parent
self.step = start self.step = start
self.prev = 0 self.prev = 0
self.steps = Progress(self.step, self.prev)
self.stages = Progress(0, 0)
self.tiles = Progress(0, 0)
self.result = None self.result = None
def __call__(self, step: int, timestep: int, latents: Any) -> None: def __call__(self, step: int, timestep: int, latents: Any) -> None:
@ -47,20 +41,11 @@ class ChainProgress:
self.prev += self.step self.prev += self.step
self.step = step self.step = step
self.parent(self.get_total(), timestep, latents)
total = self.get_total()
self.steps.current = total
self.parent(total, timestep, latents)
def get_total(self) -> int: def get_total(self) -> int:
return self.step + self.prev return self.step + self.prev
def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
self.prev = steps
self.steps.total = steps
self.stages.total = stages
self.tiles.total = tiles
@classmethod @classmethod
def from_progress(cls, parent: ProgressCallback): def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0 start = parent.step if hasattr(parent, "step") else 0
@ -144,7 +129,7 @@ class ChainPipeline:
size = pipeline_kwargs["size"] size = pipeline_kwargs["size"]
steps = self.steps(params, size) steps = self.steps(params, size)
callback.set_total(steps, stages=len(self.stages), tiles=0) worker.set_totals(steps, stages=len(self.stages), tiles=0)
start = monotonic() start = monotonic()
@ -167,7 +152,7 @@ class ChainPipeline:
len(stage_sources), len(stage_sources),
kwargs.keys(), kwargs.keys(),
) )
callback.stages.current = stage_i worker.set_stages(stage_i)
per_stage_params = params per_stage_params = params
if "params" in kwargs: if "params" in kwargs:
@ -191,7 +176,7 @@ class ChainPipeline:
if stage_pipe.max_tile > 0: if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size) tile = min(stage_pipe.max_tile, stage_params.tile_size)
callback.tiles.current = 0 # reset this either way worker.set_tiles(0)
if must_tile: if must_tile:
logger.info( logger.info(
"image contains sources or is larger than tile size of %s, tiling stage", "image contains sources or is larger than tile size of %s, tiling stage",
@ -223,7 +208,7 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()): for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image) save_image(server, f"last-tile-{j}.png", image)
callback.tiles.current = callback.tiles.current + 1 worker.set_tiles(worker.tiles.current + 1)
return tile_result return tile_result
except CancelledException as err: except CancelledException as err:

View File

@ -38,6 +38,13 @@ class Progress:
"total": self.total, "total": self.total,
} }
def complete(self) -> bool:
return self.current >= self.total
def empty(self) -> bool:
# TODO: what if total is also 0?
return self.current == 0
class ProgressCommand: class ProgressCommand:
device: str device: str

View File

@ -30,6 +30,11 @@ class WorkerContext:
initial_retries: int initial_retries: int
callback: Optional[Any] callback: Optional[Any]
# progress state
steps: Progress
stages: Progress
tiles: Progress
def __init__( def __init__(
self, self,
name: str, name: str,
@ -58,6 +63,9 @@ class WorkerContext:
self.retries = retries self.retries = retries
self.timeout = timeout self.timeout = timeout
self.callback = None self.callback = None
self.steps = Progress(0, 0)
self.stages = Progress(0, 0)
self.tiles = Progress(0, 0)
def start(self, job: JobCommand) -> None: def start(self, job: JobCommand) -> None:
# set job name and type # set job name and type
@ -94,22 +102,13 @@ class WorkerContext:
return self.get_last_steps() return self.get_last_steps()
def get_last_steps(self) -> Progress: def get_last_steps(self) -> Progress:
if self.last_progress is not None: return self.steps
return self.last_progress.steps
return Progress(0, 0)
def get_last_stages(self) -> Progress: def get_last_stages(self) -> Progress:
if self.last_progress is not None: return self.stages
return self.last_progress.stages
return Progress(0, 0)
def get_last_tiles(self) -> Progress: def get_last_tiles(self) -> Progress:
if self.last_progress is not None: return self.tiles
return self.last_progress.tiles
return Progress(0, 0)
def get_progress_callback(self, reset=False) -> ProgressCallback: def get_progress_callback(self, reset=False) -> ProgressCallback:
from ..chain.pipeline import ChainProgress from ..chain.pipeline import ChainProgress
@ -118,11 +117,8 @@ class WorkerContext:
return self.callback return self.callback
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
# self.callback.step = step
self.set_progress( self.set_progress(
step, step,
stages=self.callback.stages.current,
tiles=self.callback.tiles.current,
) )
self.callback = ChainProgress.from_progress(on_progress) self.callback = ChainProgress.from_progress(on_progress)
@ -136,32 +132,37 @@ class WorkerContext:
with self.idle.get_lock(): with self.idle.get_lock():
self.idle.value = idle self.idle.value = idle
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None: def set_progress(self, steps: int, stages: int = None, tiles: int = None) -> None:
if self.job is None: if self.job is None:
raise RuntimeError("no job on which to set progress") raise RuntimeError("no job on which to set progress")
if self.is_cancelled(): if self.is_cancelled():
raise CancelledException("job has been cancelled") raise CancelledException("job has been cancelled")
# update current progress counters
self.steps.current = steps
if stages is not None:
self.stages.current = stages
if tiles is not None:
self.tiles.current = tiles
# TODO: result should really be part of context at this point
result = None result = None
total_steps = 0
total_stages = 0
total_tiles = 0
if self.callback is not None: if self.callback is not None:
result = self.callback.result result = self.callback.result
total_steps = self.callback.steps.total
total_stages = self.callback.stages.total
total_tiles = self.callback.tiles.total
# send progress to worker pool
logger.debug("setting progress for job %s to %s", self.job, steps) logger.debug("setting progress for job %s to %s", self.job, steps)
self.last_progress = ProgressCommand( self.last_progress = ProgressCommand(
self.job, self.job,
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.RUNNING, JobStatus.RUNNING,
steps=Progress(steps, total_steps), steps=self.steps,
stages=Progress(stages, total_stages), stages=self.stages,
tiles=Progress(tiles, total_tiles), tiles=self.tiles,
result=result, result=result,
) )
self.progress.put( self.progress.put(
@ -169,20 +170,48 @@ class WorkerContext:
block=False, block=False,
) )
def set_steps(self, current: int, total: int = 0) -> None:
if total > 0:
self.steps = Progress(current, total)
else:
self.steps.current = current
def set_stages(self, current: int, total: int = 0) -> None:
if total > 0:
self.stages = Progress(current, total)
else:
self.stages.current = current
def set_tiles(self, current: int, total: int = 0) -> None:
if total > 0:
self.tiles = Progress(current, total)
else:
self.tiles.current = current
def set_totals(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
self.steps.total = steps
self.stages.total = stages
self.tiles.total = tiles
def finish(self) -> None: def finish(self) -> None:
if self.job is None: if self.job is None:
logger.warning("setting finished without an active job") logger.warning("setting finished without an active job")
else: else:
logger.debug("setting finished for job %s", self.job) logger.debug("setting finished for job %s", self.job)
result = None
if self.callback is not None:
result = self.callback.result
self.last_progress = ProgressCommand( self.last_progress = ProgressCommand(
self.job, self.job,
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.SUCCESS, JobStatus.SUCCESS,
steps=self.callback.steps, steps=self.steps,
stages=self.callback.stages, stages=self.stages,
tiles=self.callback.tiles, tiles=self.tiles,
result=self.callback.result, result=result,
) )
self.progress.put( self.progress.put(
self.last_progress, self.last_progress,
@ -200,10 +229,10 @@ class WorkerContext:
self.job_type, self.job_type,
self.device.device, self.device.device,
JobStatus.FAILED, JobStatus.FAILED,
steps=self.get_last_steps(), steps=self.steps,
stages=self.get_last_stages(), stages=self.stages,
tiles=self.get_last_tiles(), tiles=self.tiles,
# TODO: should this have results? # TODO: should this include partial results?
) )
self.progress.put( self.progress.put(
self.last_progress, self.last_progress,