move progress state to worker context
This commit is contained in:
parent
4f230f4111
commit
9b5d00a66a
|
@ -26,19 +26,13 @@ class ChainProgress:
|
||||||
step: int # same as steps.current, left for legacy purposes
|
step: int # same as steps.current, left for legacy purposes
|
||||||
prev: int # accumulator when step resets
|
prev: int # accumulator when step resets
|
||||||
|
|
||||||
# new progress trackers
|
# TODO: should probably be moved to worker context as well
|
||||||
steps: Progress
|
|
||||||
stages: Progress
|
|
||||||
tiles: Progress
|
|
||||||
result: Optional[StageResult]
|
result: Optional[StageResult]
|
||||||
|
|
||||||
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
def __init__(self, parent: ProgressCallback, start=0) -> None:
|
||||||
self.parent = parent
|
self.parent = parent
|
||||||
self.step = start
|
self.step = start
|
||||||
self.prev = 0
|
self.prev = 0
|
||||||
self.steps = Progress(self.step, self.prev)
|
|
||||||
self.stages = Progress(0, 0)
|
|
||||||
self.tiles = Progress(0, 0)
|
|
||||||
self.result = None
|
self.result = None
|
||||||
|
|
||||||
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
def __call__(self, step: int, timestep: int, latents: Any) -> None:
|
||||||
|
@ -47,20 +41,11 @@ class ChainProgress:
|
||||||
self.prev += self.step
|
self.prev += self.step
|
||||||
|
|
||||||
self.step = step
|
self.step = step
|
||||||
|
self.parent(self.get_total(), timestep, latents)
|
||||||
total = self.get_total()
|
|
||||||
self.steps.current = total
|
|
||||||
self.parent(total, timestep, latents)
|
|
||||||
|
|
||||||
def get_total(self) -> int:
|
def get_total(self) -> int:
|
||||||
return self.step + self.prev
|
return self.step + self.prev
|
||||||
|
|
||||||
def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
|
||||||
self.prev = steps
|
|
||||||
self.steps.total = steps
|
|
||||||
self.stages.total = stages
|
|
||||||
self.tiles.total = tiles
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_progress(cls, parent: ProgressCallback):
|
def from_progress(cls, parent: ProgressCallback):
|
||||||
start = parent.step if hasattr(parent, "step") else 0
|
start = parent.step if hasattr(parent, "step") else 0
|
||||||
|
@ -144,7 +129,7 @@ class ChainPipeline:
|
||||||
size = pipeline_kwargs["size"]
|
size = pipeline_kwargs["size"]
|
||||||
steps = self.steps(params, size)
|
steps = self.steps(params, size)
|
||||||
|
|
||||||
callback.set_total(steps, stages=len(self.stages), tiles=0)
|
worker.set_totals(steps, stages=len(self.stages), tiles=0)
|
||||||
|
|
||||||
start = monotonic()
|
start = monotonic()
|
||||||
|
|
||||||
|
@ -167,7 +152,7 @@ class ChainPipeline:
|
||||||
len(stage_sources),
|
len(stage_sources),
|
||||||
kwargs.keys(),
|
kwargs.keys(),
|
||||||
)
|
)
|
||||||
callback.stages.current = stage_i
|
worker.set_stages(stage_i)
|
||||||
|
|
||||||
per_stage_params = params
|
per_stage_params = params
|
||||||
if "params" in kwargs:
|
if "params" in kwargs:
|
||||||
|
@ -191,7 +176,7 @@ class ChainPipeline:
|
||||||
if stage_pipe.max_tile > 0:
|
if stage_pipe.max_tile > 0:
|
||||||
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
tile = min(stage_pipe.max_tile, stage_params.tile_size)
|
||||||
|
|
||||||
callback.tiles.current = 0 # reset this either way
|
worker.set_tiles(0)
|
||||||
if must_tile:
|
if must_tile:
|
||||||
logger.info(
|
logger.info(
|
||||||
"image contains sources or is larger than tile size of %s, tiling stage",
|
"image contains sources or is larger than tile size of %s, tiling stage",
|
||||||
|
@ -223,7 +208,7 @@ class ChainPipeline:
|
||||||
for j, image in enumerate(tile_result.as_image()):
|
for j, image in enumerate(tile_result.as_image()):
|
||||||
save_image(server, f"last-tile-{j}.png", image)
|
save_image(server, f"last-tile-{j}.png", image)
|
||||||
|
|
||||||
callback.tiles.current = callback.tiles.current + 1
|
worker.set_tiles(worker.tiles.current + 1)
|
||||||
|
|
||||||
return tile_result
|
return tile_result
|
||||||
except CancelledException as err:
|
except CancelledException as err:
|
||||||
|
|
|
@ -38,6 +38,13 @@ class Progress:
|
||||||
"total": self.total,
|
"total": self.total,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def complete(self) -> bool:
|
||||||
|
return self.current >= self.total
|
||||||
|
|
||||||
|
def empty(self) -> bool:
|
||||||
|
# TODO: what if total is also 0?
|
||||||
|
return self.current == 0
|
||||||
|
|
||||||
|
|
||||||
class ProgressCommand:
|
class ProgressCommand:
|
||||||
device: str
|
device: str
|
||||||
|
|
|
@ -30,6 +30,11 @@ class WorkerContext:
|
||||||
initial_retries: int
|
initial_retries: int
|
||||||
callback: Optional[Any]
|
callback: Optional[Any]
|
||||||
|
|
||||||
|
# progress state
|
||||||
|
steps: Progress
|
||||||
|
stages: Progress
|
||||||
|
tiles: Progress
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
@ -58,6 +63,9 @@ class WorkerContext:
|
||||||
self.retries = retries
|
self.retries = retries
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.callback = None
|
self.callback = None
|
||||||
|
self.steps = Progress(0, 0)
|
||||||
|
self.stages = Progress(0, 0)
|
||||||
|
self.tiles = Progress(0, 0)
|
||||||
|
|
||||||
def start(self, job: JobCommand) -> None:
|
def start(self, job: JobCommand) -> None:
|
||||||
# set job name and type
|
# set job name and type
|
||||||
|
@ -94,22 +102,13 @@ class WorkerContext:
|
||||||
return self.get_last_steps()
|
return self.get_last_steps()
|
||||||
|
|
||||||
def get_last_steps(self) -> Progress:
|
def get_last_steps(self) -> Progress:
|
||||||
if self.last_progress is not None:
|
return self.steps
|
||||||
return self.last_progress.steps
|
|
||||||
|
|
||||||
return Progress(0, 0)
|
|
||||||
|
|
||||||
def get_last_stages(self) -> Progress:
|
def get_last_stages(self) -> Progress:
|
||||||
if self.last_progress is not None:
|
return self.stages
|
||||||
return self.last_progress.stages
|
|
||||||
|
|
||||||
return Progress(0, 0)
|
|
||||||
|
|
||||||
def get_last_tiles(self) -> Progress:
|
def get_last_tiles(self) -> Progress:
|
||||||
if self.last_progress is not None:
|
return self.tiles
|
||||||
return self.last_progress.tiles
|
|
||||||
|
|
||||||
return Progress(0, 0)
|
|
||||||
|
|
||||||
def get_progress_callback(self, reset=False) -> ProgressCallback:
|
def get_progress_callback(self, reset=False) -> ProgressCallback:
|
||||||
from ..chain.pipeline import ChainProgress
|
from ..chain.pipeline import ChainProgress
|
||||||
|
@ -118,11 +117,8 @@ class WorkerContext:
|
||||||
return self.callback
|
return self.callback
|
||||||
|
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
# self.callback.step = step
|
|
||||||
self.set_progress(
|
self.set_progress(
|
||||||
step,
|
step,
|
||||||
stages=self.callback.stages.current,
|
|
||||||
tiles=self.callback.tiles.current,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.callback = ChainProgress.from_progress(on_progress)
|
self.callback = ChainProgress.from_progress(on_progress)
|
||||||
|
@ -136,32 +132,37 @@ class WorkerContext:
|
||||||
with self.idle.get_lock():
|
with self.idle.get_lock():
|
||||||
self.idle.value = idle
|
self.idle.value = idle
|
||||||
|
|
||||||
def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
def set_progress(self, steps: int, stages: int = None, tiles: int = None) -> None:
|
||||||
if self.job is None:
|
if self.job is None:
|
||||||
raise RuntimeError("no job on which to set progress")
|
raise RuntimeError("no job on which to set progress")
|
||||||
|
|
||||||
if self.is_cancelled():
|
if self.is_cancelled():
|
||||||
raise CancelledException("job has been cancelled")
|
raise CancelledException("job has been cancelled")
|
||||||
|
|
||||||
|
# update current progress counters
|
||||||
|
self.steps.current = steps
|
||||||
|
|
||||||
|
if stages is not None:
|
||||||
|
self.stages.current = stages
|
||||||
|
|
||||||
|
if tiles is not None:
|
||||||
|
self.tiles.current = tiles
|
||||||
|
|
||||||
|
# TODO: result should really be part of context at this point
|
||||||
result = None
|
result = None
|
||||||
total_steps = 0
|
|
||||||
total_stages = 0
|
|
||||||
total_tiles = 0
|
|
||||||
if self.callback is not None:
|
if self.callback is not None:
|
||||||
result = self.callback.result
|
result = self.callback.result
|
||||||
total_steps = self.callback.steps.total
|
|
||||||
total_stages = self.callback.stages.total
|
|
||||||
total_tiles = self.callback.tiles.total
|
|
||||||
|
|
||||||
|
# send progress to worker pool
|
||||||
logger.debug("setting progress for job %s to %s", self.job, steps)
|
logger.debug("setting progress for job %s to %s", self.job, steps)
|
||||||
self.last_progress = ProgressCommand(
|
self.last_progress = ProgressCommand(
|
||||||
self.job,
|
self.job,
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.RUNNING,
|
JobStatus.RUNNING,
|
||||||
steps=Progress(steps, total_steps),
|
steps=self.steps,
|
||||||
stages=Progress(stages, total_stages),
|
stages=self.stages,
|
||||||
tiles=Progress(tiles, total_tiles),
|
tiles=self.tiles,
|
||||||
result=result,
|
result=result,
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
|
@ -169,20 +170,48 @@ class WorkerContext:
|
||||||
block=False,
|
block=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_steps(self, current: int, total: int = 0) -> None:
|
||||||
|
if total > 0:
|
||||||
|
self.steps = Progress(current, total)
|
||||||
|
else:
|
||||||
|
self.steps.current = current
|
||||||
|
|
||||||
|
def set_stages(self, current: int, total: int = 0) -> None:
|
||||||
|
if total > 0:
|
||||||
|
self.stages = Progress(current, total)
|
||||||
|
else:
|
||||||
|
self.stages.current = current
|
||||||
|
|
||||||
|
def set_tiles(self, current: int, total: int = 0) -> None:
|
||||||
|
if total > 0:
|
||||||
|
self.tiles = Progress(current, total)
|
||||||
|
else:
|
||||||
|
self.tiles.current = current
|
||||||
|
|
||||||
|
def set_totals(self, steps: int, stages: int = 0, tiles: int = 0) -> None:
|
||||||
|
self.steps.total = steps
|
||||||
|
self.stages.total = stages
|
||||||
|
self.tiles.total = tiles
|
||||||
|
|
||||||
def finish(self) -> None:
|
def finish(self) -> None:
|
||||||
if self.job is None:
|
if self.job is None:
|
||||||
logger.warning("setting finished without an active job")
|
logger.warning("setting finished without an active job")
|
||||||
else:
|
else:
|
||||||
logger.debug("setting finished for job %s", self.job)
|
logger.debug("setting finished for job %s", self.job)
|
||||||
|
|
||||||
|
result = None
|
||||||
|
if self.callback is not None:
|
||||||
|
result = self.callback.result
|
||||||
|
|
||||||
self.last_progress = ProgressCommand(
|
self.last_progress = ProgressCommand(
|
||||||
self.job,
|
self.job,
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.SUCCESS,
|
JobStatus.SUCCESS,
|
||||||
steps=self.callback.steps,
|
steps=self.steps,
|
||||||
stages=self.callback.stages,
|
stages=self.stages,
|
||||||
tiles=self.callback.tiles,
|
tiles=self.tiles,
|
||||||
result=self.callback.result,
|
result=result,
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
@ -200,10 +229,10 @@ class WorkerContext:
|
||||||
self.job_type,
|
self.job_type,
|
||||||
self.device.device,
|
self.device.device,
|
||||||
JobStatus.FAILED,
|
JobStatus.FAILED,
|
||||||
steps=self.get_last_steps(),
|
steps=self.steps,
|
||||||
stages=self.get_last_stages(),
|
stages=self.stages,
|
||||||
tiles=self.get_last_tiles(),
|
tiles=self.tiles,
|
||||||
# TODO: should this have results?
|
# TODO: should this include partial results?
|
||||||
)
|
)
|
||||||
self.progress.put(
|
self.progress.put(
|
||||||
self.last_progress,
|
self.last_progress,
|
||||||
|
|
Loading…
Reference in New Issue