diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index c79a981f..af99611b 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,7 +1,7 @@ from collections import Counter from logging import getLogger from queue import Empty -from threading import Thread +from threading import Thread, Lock from typing import Callable, Dict, List, Optional, Tuple from torch.multiprocessing import Process, Queue, Value @@ -37,6 +37,7 @@ class DevicePoolExecutor: logs: "Queue[str]" progress: "Queue[ProgressCommand]" + rlock: Lock def __init__( self, @@ -66,6 +67,7 @@ class DevicePoolExecutor: self.logs = Queue(self.max_pending_per_worker) self.progress = Queue(self.max_pending_per_worker) + self.rlock = Lock() # TODO: these should be part of a start method self.create_logger_worker() @@ -223,35 +225,36 @@ class DevicePoolExecutor: def join(self): logger.info("stopping worker pool") - logger.debug("closing queues") - self.logs.close() - self.progress.close() - for queue in self.pending.values(): - queue.close() + with self.rlock: + logger.debug("closing queues") + self.logs.close() + self.progress.close() + for queue in self.pending.values(): + queue.close() - self.pending.clear() - self.join_leaking() + self.pending.clear() + self.join_leaking() - logger.debug("stopping device workers") - for device, worker in self.workers.items(): - if worker.is_alive(): - logger.debug("stopping worker %s for device %s", worker.pid, device) - worker.join(self.join_timeout) + logger.debug("stopping device workers") + for device, worker in self.workers.items(): if worker.is_alive(): - logger.warning( - "worker %s for device %s could not be stopped in time", - worker.pid, - device, - ) - self.leaking.append((device, worker)) - else: - logger.debug("worker for device %s has died", device) + logger.debug("stopping worker %s for device %s", worker.pid, device) + worker.join(self.join_timeout) + if worker.is_alive(): + logger.warning( + "worker %s for device %s could not be stopped in time", + worker.pid, + device, + ) + self.leaking.append((device, worker)) + else: + logger.debug("worker for device %s has died", device) - for name, thread in self.threads.items(): - logger.debug("stopping worker %s for thread %s", thread.ident, name) - thread.join(self.join_timeout) + for name, thread in self.threads.items(): + logger.debug("stopping worker %s for thread %s", thread.ident, name) + thread.join(self.join_timeout) - logger.debug("worker pool stopped") + logger.debug("worker pool stopped") def join_leaking(self): if len(self.leaking) > 0: @@ -272,56 +275,62 @@ class DevicePoolExecutor: def recycle(self): logger.debug("recycling worker pool") - self.join_leaking() - needs_restart = [] + with self.rlock: + self.join_leaking() - for device, worker in self.workers.items(): - jobs = self.total_jobs.get(device, 0) - if not worker.is_alive(): - logger.warning("worker for device %s has died", device) - needs_restart.append(device) - elif jobs > self.max_jobs_per_worker: - logger.info( - "shutting down worker for device %s after %s jobs", device, jobs - ) - worker.join(self.join_timeout) - if worker.is_alive(): - logger.warning( - "worker %s for device %s could not be recycled in time", + needs_restart = [] + + for device, worker in self.workers.items(): + jobs = self.total_jobs.get(device, 0) + if not worker.is_alive(): + logger.warning("worker for device %s has died", device) + needs_restart.append(device) + elif jobs > self.max_jobs_per_worker: + logger.info( + "shutting down worker for device %s after %s jobs", device, jobs + ) + worker.join(self.join_timeout) + if worker.is_alive(): + logger.warning( + "worker %s for device %s could not be recycled in time", + worker.pid, + device, + ) + self.leaking.append((device, worker)) + else: + del worker + + self.workers[device] = None + needs_restart.append(device) + else: + logger.debug( + "worker %s for device %s does not need to be recycled", worker.pid, device, ) - self.leaking.append((device, worker)) - else: - del worker - self.workers[device] = None - needs_restart.append(device) + if len(needs_restart) > 0: + logger.info("starting new workers") + + for device in self.devices: + if device.device in needs_restart: + self.create_device_worker(device) + self.total_jobs[device.device] = 0 + + if self.threads["logger"].is_alive(): + logger.debug("logger worker is running") else: - logger.debug( - "worker %s for device %s does not need to be recycled", - worker.pid, - device, - ) + logger.warning("restarting logger worker") + self.create_logger_worker() - if len(needs_restart) > 0: - logger.debug("starting new workers") + if self.threads["progress"].is_alive(): + logger.debug("progress worker is running") + else: + logger.warning("restarting progress worker") + self.create_progress_worker() - for device in self.devices: - if device.device in needs_restart: - self.create_device_worker(device) - self.total_jobs[device.device] = 0 - - if not self.threads["logger"].is_alive(): - logger.warning("restarting crashed logger worker") - self.create_logger_worker() - - if not self.threads["progress"].is_alive(): - logger.warning("restarting crashed progress worker") - self.create_progress_worker() - - logger.debug("worker pool recycled") + logger.debug("worker pool recycled") def submit( self,