From 588c8c7fdb4007f8d83ece40ebdac06f94089298 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Mar 2023 15:32:49 -0500 Subject: [PATCH] fix(api): track last progress within worker --- api/onnx_web/server/api.py | 4 +- api/onnx_web/worker/context.py | 76 +++++++++++++++++++--------------- api/onnx_web/worker/pool.py | 2 +- api/onnx_web/worker/worker.py | 4 +- 4 files changed, 48 insertions(+), 38 deletions(-) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 5783ddc8..6619097d 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -457,7 +457,9 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): if path.exists(output): return ready_reply(True) else: - return ready_reply(True, error=True) # is a missing image really an error? yes will display the retry button + return ready_reply( + True, error=True + ) # is a missing image really an error? yes will display the retry button return ready_reply( progress.finished, diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index fc65c8a2..fd497434 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -1,6 +1,6 @@ from logging import getLogger from os import getpid -from typing import Any, Callable +from typing import Any, Callable, Optional from torch.multiprocessing import Queue, Value @@ -17,8 +17,9 @@ class WorkerContext: cancel: "Value[bool]" job: str pending: "Queue[JobCommand]" - current: "Value[int]" + active_pid: "Value[int]" progress: "Queue[ProgressCommand]" + last_progress: Optional[ProgressCommand] def __init__( self, @@ -28,7 +29,7 @@ class WorkerContext: logs: "Queue[str]", pending: "Queue[JobCommand]", progress: "Queue[ProgressCommand]", - current: "Value[int]", + active_pid: "Value[int]", ): self.job = job self.device = device @@ -36,17 +37,17 @@ class WorkerContext: self.progress = progress self.logs = logs self.pending = pending - self.current = current + self.active_pid = active_pid def is_cancelled(self) -> bool: return self.cancel.value - def is_current(self) -> bool: - return self.get_current() == getpid() + def is_active(self) -> bool: + return self.get_active() == getpid() - def get_current(self) -> int: - with self.current.get_lock(): - return self.current.value + def get_active(self) -> int: + with self.active_pid.get_lock(): + return self.active_pid.value def get_device(self) -> DeviceParams: """ @@ -55,7 +56,10 @@ class WorkerContext: return self.device def get_progress(self) -> int: - return self.progress.value + if self.last_progress is not None: + return self.last_progress.progress + + return 0 def get_progress_callback(self) -> ProgressCallback: def on_progress(step: int, timestep: int, latents: Any): @@ -73,44 +77,48 @@ class WorkerContext: raise RuntimeError("job has been cancelled") else: logger.debug("setting progress for job %s to %s", self.job, progress) + self.last_progress = ProgressCommand( + self.job, + self.device.device, + False, + progress, + self.is_cancelled(), + False, + ) + self.progress.put( - ProgressCommand( - self.job, - self.device.device, - False, - progress, - self.is_cancelled(), - False, - ), + self.last_progress, block=False, ) def set_finished(self) -> None: logger.debug("setting finished for job %s", self.job) + self.last_progress = ProgressCommand( + self.job, + self.device.device, + True, + self.get_progress(), + self.is_cancelled(), + False, + ) self.progress.put( - ProgressCommand( - self.job, - self.device.device, - True, - self.get_progress(), - self.is_cancelled(), - False, - ), + self.last_progress, block=False, ) def set_failed(self) -> None: logger.warning("setting failure for job %s", self.job) try: + self.last_progress = ProgressCommand( + self.job, + self.device.device, + True, + self.get_progress(), + self.is_cancelled(), + True, + ) self.progress.put( - ProgressCommand( - self.job, - self.device.device, - True, - self.get_progress(), - self.is_cancelled(), - True, - ), + self.last_progress, block=False, ) except Exception: diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 6f934bda..dc4b95fc 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -99,7 +99,7 @@ class DevicePoolExecutor: progress=self.progress, logs=self.logs, pending=pending, - current=current, + active_pid=current, ) self.context[name] = context worker = Process( diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 0d320128..6afd60b3 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -29,11 +29,11 @@ def worker_main(context: WorkerContext, server: ServerContext): while True: try: - if not context.is_current(): + if not context.is_active(): logger.warning( "worker %s has been replaced by %s, exiting", getpid(), - context.get_current(), + context.get_active(), ) exit(EXIT_REPLACED)