diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 24d96c6a..411145a2 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -1,4 +1,5 @@ from logging import getLogger +from os import getpid from typing import Any, Callable, Tuple from torch.multiprocessing import Queue, Value @@ -15,6 +16,7 @@ class WorkerContext: cancel: "Value[bool]" job: str pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]" + current: "Value[int]" progress: "Queue[Tuple[str, str, int]]" def __init__( @@ -26,6 +28,7 @@ class WorkerContext: pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]", progress: "Queue[Tuple[str, str, int]]", finished: "Queue[Tuple[str, str]]", + current: "Value[int]", ): self.job = job self.device = device @@ -34,10 +37,17 @@ class WorkerContext: self.finished = finished self.logs = logs self.pending = pending + self.current = current def is_cancelled(self) -> bool: return self.cancel.value + def is_current(self) -> bool: + if self.current.value > 0: + return self.current.value == getpid() + + return True + def get_device(self) -> DeviceParams: """ Get the device assigned to this job. diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 907658df..3a7cf0b0 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -23,6 +23,7 @@ class DevicePoolExecutor: leaking: List[Tuple[str, Process]] context: Dict[str, WorkerContext] # Device -> Context + current: Dict[str, "Value[int]"] pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"] threads: Dict[str, Thread] workers: Dict[str, Process] @@ -52,6 +53,7 @@ class DevicePoolExecutor: self.leaking = [] self.context = {} + self.current = {} self.pending = {} self.threads = {} self.workers = {} @@ -85,6 +87,14 @@ class DevicePoolExecutor: pending = Queue(self.max_pending_per_worker) self.pending[name] = pending + if name in self.current: + logger.debug("using existing current worker value") + current = self.current[name] + else: + logger.debug("creating new current worker value") + current = Value("L", 0) + self.current[name] = current + context = WorkerContext( name, device, @@ -93,16 +103,19 @@ class DevicePoolExecutor: finished=self.finished, logs=self.logs, pending=pending, + current=current, ) self.context[name] = context - self.workers[name] = Process( + worker = Process( name=f"onnx-web worker: {name}", target=worker_main, args=(context, self.server), ) logger.debug("starting worker for device %s", device) - self.workers[name].start() + worker.start() + self.workers[name] = worker + current.value = worker.pid def create_logger_worker(self) -> None: def logger_worker(logs: Queue): diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index e53a060c..640c6fe0 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -24,6 +24,10 @@ def worker_main(context: WorkerContext, server: ServerContext): while True: try: + if not context.is_current(): + logger.warning("worker has been replaced, exiting") + exit(3) + name, fn, args, kwargs = context.pending.get(timeout=1.0) logger.info("worker for %s got job: %s", context.device.device, name)