diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 77041d03..d9b838da 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -21,6 +21,7 @@ class DevicePoolExecutor: max_pending_per_worker: int join_timeout: float + leaking: List[Tuple[str, Process]] context: Dict[str, WorkerContext] # Device -> Context pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"] threads: Dict[str, Thread] @@ -49,6 +50,7 @@ class DevicePoolExecutor: self.max_pending_per_worker = max_pending_per_worker self.join_timeout = join_timeout + self.leaking = [] self.context = {} self.pending = {} self.threads = {} @@ -133,7 +135,7 @@ class DevicePoolExecutor: while True: try: job, device, value = progress.get(timeout=(self.join_timeout / 2)) - logger.info("progress update for job: %s to %s", job, value) + logger.debug("progress update for job: %s to %s", job, value) self.active_jobs[job] = (device, value) if job in self.cancelled_jobs: logger.debug( @@ -270,6 +272,7 @@ class DevicePoolExecutor: logger.warning( "worker for device %s could not be stopped in time", device ) + self.leaking.append((device, worker)) else: logger.debug("worker for device %s has died", device) @@ -281,25 +284,38 @@ class DevicePoolExecutor: def recycle(self): logger.debug("recycling worker pool") + + if len(self.leaking) > 0: + logger.warning("cleaning up %s leaking workers", len(self.leaking)) + for device, worker in self.leaking: + logger.debug("shutting down worker for device %s", device) + worker.join(self.join_timeout) + if worker.is_alive(): + logger.error("leaking worker for device %s could not be shut down", device) + + self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()] + needs_restart = [] - for device, proc in self.workers.items(): + for device, worker in self.workers.items(): jobs = self.total_jobs.get(device, 0) - if not proc.is_alive(): + if not worker.is_alive(): logger.warning("worker for device %s has died", device) needs_restart.append(device) elif jobs > self.max_jobs_per_worker: logger.info( "shutting down worker for device %s after %s jobs", device, jobs ) - proc.join(self.join_timeout) - if proc.is_alive(): + worker.join(self.join_timeout) + if worker.is_alive(): logger.warning( "worker for device %s could not be recycled in time", device ) + self.leaking.append((device, worker)) + else: + del worker self.workers[device] = None - del proc needs_restart.append(device) else: logger.debug(