1
0
Fork 0

move progress state to worker context

This commit is contained in:
Sean Sube 2024-01-05 20:12:41 -06:00
parent 4f230f4111
commit 9b5d00a66a
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 75 additions and 54 deletions

View File

@ -26,19 +26,13 @@ class ChainProgress:
step: int # same as steps.current, left for legacy purposes
prev: int # accumulator when step resets
# new progress trackers
steps: Progress
stages: Progress
tiles: Progress
# TODO: should probably be moved to worker context as well
result: Optional[StageResult]
def __init__(self, parent: ProgressCallback, start=0) -> None:
self.parent = parent
self.step = start
self.prev = 0
self.steps = Progress(self.step, self.prev)
self.stages = Progress(0, 0)
self.tiles = Progress(0, 0)
self.result = None
def __call__(self, step: int, timestep: int, latents: Any) -> None:
@ -47,20 +41,11 @@ class ChainProgress:
self.prev += self.step
self.step = step
total = self.get_total()
self.steps.current = total
self.parent(total, timestep, latents)
self.parent(self.get_total(), timestep, latents)
def get_total(self) -> int:
return self.step + self.prev
def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
self.prev = steps
self.steps.total = steps
self.stages.total = stages
self.tiles.total = tiles
@classmethod
def from_progress(cls, parent: ProgressCallback):
start = parent.step if hasattr(parent, "step") else 0
@ -144,7 +129,7 @@ class ChainPipeline:
size = pipeline_kwargs["size"]
steps = self.steps(params, size)
callback.set_total(steps, stages=len(self.stages), tiles=0)
worker.set_totals(steps, stages=len(self.stages), tiles=0)
start = monotonic()
@ -167,7 +152,7 @@ class ChainPipeline:
len(stage_sources),
kwargs.keys(),
)
callback.stages.current = stage_i
worker.set_stages(stage_i)
per_stage_params = params
if "params" in kwargs:
@ -191,7 +176,7 @@ class ChainPipeline:
if stage_pipe.max_tile > 0:
tile = min(stage_pipe.max_tile, stage_params.tile_size)
callback.tiles.current = 0 # reset this either way
worker.set_tiles(0)
if must_tile:
logger.info(
"image contains sources or is larger than tile size of %s, tiling stage",
@ -223,7 +208,7 @@ class ChainPipeline:
for j, image in enumerate(tile_result.as_image()):
save_image(server, f"last-tile-{j}.png", image)
callback.tiles.current = callback.tiles.current + 1
worker.set_tiles(worker.tiles.current + 1)
return tile_result
except CancelledException as err:

View File

@ -38,6 +38,13 @@ class Progress:
"total": self.total,
}
def complete(self) -> bool:
return self.current >= self.total
def empty(self) -> bool:
# TODO: what if total is also 0?
return self.current == 0
class ProgressCommand:
device: str

View File

@ -30,6 +30,11 @@ class WorkerContext:
initial_retries: int
callback: Optional[Any]
# progress state
steps: Progress
stages: Progress
tiles: Progress
def __init__(
self,
name: str,
@ -58,6 +63,9 @@ class WorkerContext:
self.retries = retries
self.timeout = timeout
self.callback = None
self.steps = Progress(0, 0)
self.stages = Progress(0, 0)
self.tiles = Progress(0, 0)
def start(self, job: JobCommand) -> None:
# set job name and type
@ -94,22 +102,13 @@ class WorkerContext:
return self.get_last_steps()
def get_last_steps(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.steps
return Progress(0, 0)
return self.steps
def get_last_stages(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.stages
return Progress(0, 0)
return self.stages
def get_last_tiles(self) -> Progress:
if self.last_progress is not None:
return self.last_progress.tiles
return Progress(0, 0)
return self.tiles
def get_progress_callback(self, reset=False) -> ProgressCallback:
from ..chain.pipeline import ChainProgress
@ -118,11 +117,8 @@ class WorkerContext:
return self.callback
def on_progress(step: int, timestep: int, latents: Any):
# self.callback.step = step
self.set_progress(
step,
stages=self.callback.stages.current,
tiles=self.callback.tiles.current,
)
self.callback = ChainProgress.from_progress(on_progress)
@ -136,32 +132,37 @@ class WorkerContext:
with self.idle.get_lock():
self.idle.value = idle
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
def set_progress(self, steps: int, stages: int = None, tiles: int = None) -> None:
if self.job is None:
raise RuntimeError("no job on which to set progress")
if self.is_cancelled():
raise CancelledException("job has been cancelled")
# update current progress counters
self.steps.current = steps
if stages is not None:
self.stages.current = stages
if tiles is not None:
self.tiles.current = tiles
# TODO: result should really be part of context at this point
result = None
total_steps = 0
total_stages = 0
total_tiles = 0
if self.callback is not None:
result = self.callback.result
total_steps = self.callback.steps.total
total_stages = self.callback.stages.total
total_tiles = self.callback.tiles.total
# send progress to worker pool
logger.debug("setting progress for job %s to %s", self.job, steps)
self.last_progress = ProgressCommand(
self.job,
self.job_type,
self.device.device,
JobStatus.RUNNING,
steps=Progress(steps, total_steps),
stages=Progress(stages, total_stages),
tiles=Progress(tiles, total_tiles),
steps=self.steps,
stages=self.stages,
tiles=self.tiles,
result=result,
)
self.progress.put(
@ -169,20 +170,48 @@ class WorkerContext:
block=False,
)
def set_steps(self, current: int, total: int = 0) -> None:
if total > 0:
self.steps = Progress(current, total)
else:
self.steps.current = current
def set_stages(self, current: int, total: int = 0) -> None:
if total > 0:
self.stages = Progress(current, total)
else:
self.stages.current = current
def set_tiles(self, current: int, total: int = 0) -> None:
if total > 0:
self.tiles = Progress(current, total)
else:
self.tiles.current = current
def set_totals(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
self.steps.total = steps
self.stages.total = stages
self.tiles.total = tiles
def finish(self) -> None:
if self.job is None:
logger.warning("setting finished without an active job")
else:
logger.debug("setting finished for job %s", self.job)
result = None
if self.callback is not None:
result = self.callback.result
self.last_progress = ProgressCommand(
self.job,
self.job_type,
self.device.device,
JobStatus.SUCCESS,
steps=self.callback.steps,
stages=self.callback.stages,
tiles=self.callback.tiles,
result=self.callback.result,
steps=self.steps,
stages=self.stages,
tiles=self.tiles,
result=result,
)
self.progress.put(
self.last_progress,
@ -200,10 +229,10 @@ class WorkerContext:
self.job_type,
self.device.device,
JobStatus.FAILED,
steps=self.get_last_steps(),
stages=self.get_last_stages(),
tiles=self.get_last_tiles(),
# TODO: should this have results?
steps=self.steps,
stages=self.stages,
tiles=self.tiles,
# TODO: should this include partial results?
)
self.progress.put(
self.last_progress,