From 9b5d00a66a19e3236d3245de84f3fd802ffca5e0 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 5 Jan 2024 20:12:41 -0600 Subject: [PATCH] move progress state to worker context --- api/onnx_web/chain/pipeline.py | 27 +++------- api/onnx_web/worker/command.py | 7 +++ api/onnx_web/worker/context.py | 95 ++++++++++++++++++++++------------ 3 files changed, 75 insertions(+), 54 deletions(-) diff --git a/api/onnx_web/chain/pipeline.py b/api/onnx_web/chain/pipeline.py index 6578bf82..1a6a597b 100644 --- a/api/onnx_web/chain/pipeline.py +++ b/api/onnx_web/chain/pipeline.py @@ -26,19 +26,13 @@ class ChainProgress: step: int # same as steps.current, left for legacy purposes prev: int # accumulator when step resets - # new progress trackers - steps: Progress - stages: Progress - tiles: Progress + # TODO: should probably be moved to worker context as well result: Optional[StageResult] def __init__(self, parent: ProgressCallback, start=0) -> None: self.parent = parent self.step = start self.prev = 0 - self.steps = Progress(self.step, self.prev) - self.stages = Progress(0, 0) - self.tiles = Progress(0, 0) self.result = None def __call__(self, step: int, timestep: int, latents: Any) -> None: @@ -47,20 +41,11 @@ class ChainProgress: self.prev += self.step self.step = step - - total = self.get_total() - self.steps.current = total - self.parent(total, timestep, latents) + self.parent(self.get_total(), timestep, latents) def get_total(self) -> int: return self.step + self.prev - def set_total(self, steps: int, stages: int = 0, tiles: int = 0) -> None: - self.prev = steps - self.steps.total = steps - self.stages.total = stages - self.tiles.total = tiles - @classmethod def from_progress(cls, parent: ProgressCallback): start = parent.step if hasattr(parent, "step") else 0 @@ -144,7 +129,7 @@ class ChainPipeline: size = pipeline_kwargs["size"] steps = self.steps(params, size) - callback.set_total(steps, stages=len(self.stages), tiles=0) + worker.set_totals(steps, stages=len(self.stages), tiles=0) start = monotonic() @@ -167,7 +152,7 @@ class ChainPipeline: len(stage_sources), kwargs.keys(), ) - callback.stages.current = stage_i + worker.set_stages(stage_i) per_stage_params = params if "params" in kwargs: @@ -191,7 +176,7 @@ class ChainPipeline: if stage_pipe.max_tile > 0: tile = min(stage_pipe.max_tile, stage_params.tile_size) - callback.tiles.current = 0 # reset this either way + worker.set_tiles(0) if must_tile: logger.info( "image contains sources or is larger than tile size of %s, tiling stage", @@ -223,7 +208,7 @@ class ChainPipeline: for j, image in enumerate(tile_result.as_image()): save_image(server, f"last-tile-{j}.png", image) - callback.tiles.current = callback.tiles.current + 1 + worker.set_tiles(worker.tiles.current + 1) return tile_result except CancelledException as err: diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 31f38ee0..1ba72a2b 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -38,6 +38,13 @@ class Progress: "total": self.total, } + def complete(self) -> bool: + return self.current >= self.total + + def empty(self) -> bool: + # TODO: what if total is also 0? + return self.current == 0 + class ProgressCommand: device: str diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index a5a6325b..da68ddc6 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -30,6 +30,11 @@ class WorkerContext: initial_retries: int callback: Optional[Any] + # progress state + steps: Progress + stages: Progress + tiles: Progress + def __init__( self, name: str, @@ -58,6 +63,9 @@ class WorkerContext: self.retries = retries self.timeout = timeout self.callback = None + self.steps = Progress(0, 0) + self.stages = Progress(0, 0) + self.tiles = Progress(0, 0) def start(self, job: JobCommand) -> None: # set job name and type @@ -94,22 +102,13 @@ class WorkerContext: return self.get_last_steps() def get_last_steps(self) -> Progress: - if self.last_progress is not None: - return self.last_progress.steps - - return Progress(0, 0) + return self.steps def get_last_stages(self) -> Progress: - if self.last_progress is not None: - return self.last_progress.stages - - return Progress(0, 0) + return self.stages def get_last_tiles(self) -> Progress: - if self.last_progress is not None: - return self.last_progress.tiles - - return Progress(0, 0) + return self.tiles def get_progress_callback(self, reset=False) -> ProgressCallback: from ..chain.pipeline import ChainProgress @@ -118,11 +117,8 @@ class WorkerContext: return self.callback def on_progress(step: int, timestep: int, latents: Any): - # self.callback.step = step self.set_progress( step, - stages=self.callback.stages.current, - tiles=self.callback.tiles.current, ) self.callback = ChainProgress.from_progress(on_progress) @@ -136,32 +132,37 @@ class WorkerContext: with self.idle.get_lock(): self.idle.value = idle - def set_progress(self, steps: int, stages: int = 0, tiles: int = 0) -> None: + def set_progress(self, steps: int, stages: int = None, tiles: int = None) -> None: if self.job is None: raise RuntimeError("no job on which to set progress") if self.is_cancelled(): raise CancelledException("job has been cancelled") + # update current progress counters + self.steps.current = steps + + if stages is not None: + self.stages.current = stages + + if tiles is not None: + self.tiles.current = tiles + + # TODO: result should really be part of context at this point result = None - total_steps = 0 - total_stages = 0 - total_tiles = 0 if self.callback is not None: result = self.callback.result - total_steps = self.callback.steps.total - total_stages = self.callback.stages.total - total_tiles = self.callback.tiles.total + # send progress to worker pool logger.debug("setting progress for job %s to %s", self.job, steps) self.last_progress = ProgressCommand( self.job, self.job_type, self.device.device, JobStatus.RUNNING, - steps=Progress(steps, total_steps), - stages=Progress(stages, total_stages), - tiles=Progress(tiles, total_tiles), + steps=self.steps, + stages=self.stages, + tiles=self.tiles, result=result, ) self.progress.put( @@ -169,20 +170,48 @@ class WorkerContext: block=False, ) + def set_steps(self, current: int, total: int = 0) -> None: + if total > 0: + self.steps = Progress(current, total) + else: + self.steps.current = current + + def set_stages(self, current: int, total: int = 0) -> None: + if total > 0: + self.stages = Progress(current, total) + else: + self.stages.current = current + + def set_tiles(self, current: int, total: int = 0) -> None: + if total > 0: + self.tiles = Progress(current, total) + else: + self.tiles.current = current + + def set_totals(self, steps: int, stages: int = 0, tiles: int = 0) -> None: + self.steps.total = steps + self.stages.total = stages + self.tiles.total = tiles + def finish(self) -> None: if self.job is None: logger.warning("setting finished without an active job") else: logger.debug("setting finished for job %s", self.job) + + result = None + if self.callback is not None: + result = self.callback.result + self.last_progress = ProgressCommand( self.job, self.job_type, self.device.device, JobStatus.SUCCESS, - steps=self.callback.steps, - stages=self.callback.stages, - tiles=self.callback.tiles, - result=self.callback.result, + steps=self.steps, + stages=self.stages, + tiles=self.tiles, + result=result, ) self.progress.put( self.last_progress, @@ -200,10 +229,10 @@ class WorkerContext: self.job_type, self.device.device, JobStatus.FAILED, - steps=self.get_last_steps(), - stages=self.get_last_stages(), - tiles=self.get_last_tiles(), - # TODO: should this have results? + steps=self.steps, + stages=self.stages, + tiles=self.tiles, + # TODO: should this include partial results? ) self.progress.put( self.last_progress,