From 2d2283e1ebb7381876b5d0a35ef2a9bb78865c3e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Mar 2023 15:31:16 -0500 Subject: [PATCH] fix(api): attempt to read progress updates from recycled workers --- api/onnx_web/worker/pool.py | 39 +++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index a065d829..c8e96c65 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -26,7 +26,7 @@ class DevicePoolExecutor: progress_interval: float recycle_interval: float - leaking: List[Tuple[str, Process]] + leaking: List[Tuple[str, Process, WorkerContext]] context: Dict[str, WorkerContext] # Device -> Context current: Dict[str, "Value[int]"] # Device -> pid pending: Dict[str, "Queue[JobCommand]"] @@ -256,7 +256,7 @@ class DevicePoolExecutor: worker.pid, device, ) - self.leaking.append((device, worker)) + self.leak_worker(device) else: logger.debug("worker for device %s has died", device) @@ -273,10 +273,9 @@ class DevicePoolExecutor: def join_leaking(self): if len(self.leaking) > 0: - logger.warning("cleaning up %s leaking workers", len(self.leaking)) - for device, worker in self.leaking: - logger.debug( - "shutting down worker %s for device %s", worker.pid, device + for device, worker, _context in self.leaking: + logger.warning( + "shutting down leaking worker %s for device %s", worker.pid, device ) worker.join(self.join_timeout) if worker.is_alive(): @@ -312,7 +311,7 @@ class DevicePoolExecutor: worker.pid, device, ) - self.leaking.append((device, worker)) + self.leak_worker(device) else: del worker @@ -463,6 +462,11 @@ class DevicePoolExecutor: ) self.context[progress.device].set_cancel() + def leak_worker(self, device: str): + context = self.context[device] + worker = self.workers[device] + self.leaking.append((device, worker, context)) + def health_main(pool: DevicePoolExecutor): logger.trace("checking in from health worker thread") @@ -494,10 +498,25 @@ def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"): logger.exception("error in log worker") -def progress_main( - pool: DevicePoolExecutor -): +def progress_main(pool: DevicePoolExecutor): logger.trace("checking in from progress worker thread") + + for device, _worker, context in pool.leaking: + # whether the worker is alive or not, try to clear its queues + try: + progress = context.progress.get_nowait() + while progress is not None: + pool.update_job(progress) + progress = context.progress.get_nowait() + except Empty: + logger.trace("empty queue in leaking worker for device %s", device) + pass + except ValueError as e: + logger.debug("value error in leaking worker for device %s: %s", device, e) + break + except Exception: + logger.exception("error in leaking worker for device %s", device) + for device, queue in pool.progress.items(): try: progress = queue.get_nowait()