diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index c16d19ef..b9bc16b5 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -37,7 +37,9 @@ def pipeline_from_request( lpw = get_not_empty(request.args, "lpw", "false") == "true" model = get_not_empty(request.args, "model", get_config_value("model")) model_path = get_model_path(context, model) - scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys()) + scheduler = get_from_list( + request.args, "scheduler", list(pipeline_schedulers.keys()) + ) if scheduler is None: scheduler = get_config_value("scheduler") diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 38285df2..77041d03 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -29,7 +29,7 @@ class DevicePoolExecutor: active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus] cancelled_jobs: List[str] finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus] - total_jobs: int + total_jobs: Dict[str, int] # Device -> job count logs: "Queue" progress: "Queue[Tuple[str, str, int]]" @@ -57,12 +57,13 @@ class DevicePoolExecutor: self.active_jobs = {} self.cancelled_jobs = [] self.finished_jobs = [] - self.total_jobs = 0 # TODO: turn this into a Dict per-worker + self.total_jobs = {} self.logs = Queue(self.max_pending_per_worker) self.progress = Queue(self.max_pending_per_worker) self.finished = Queue(self.max_pending_per_worker) + # TODO: these should be part of a start method self.create_logger_worker() self.create_progress_worker() self.create_finished_worker() @@ -70,9 +71,6 @@ class DevicePoolExecutor: for device in devices: self.create_device_worker(device) - logger.debug("testing log worker") - self.logs.put("testing") - def create_device_worker(self, device: DeviceParams) -> None: name = device.device @@ -268,7 +266,10 @@ class DevicePoolExecutor: if worker.is_alive(): logger.debug("stopping worker for device %s", device) worker.join(self.join_timeout) - # worker.terminate() + if worker.is_alive(): + logger.warning( + "worker for device %s could not be stopped in time", device + ) else: logger.debug("worker for device %s has died", device) @@ -276,24 +277,43 @@ class DevicePoolExecutor: logger.debug("stopping worker thread: %s", name) thread.join(self.join_timeout) - logger.debug("worker pool fully joined") + logger.debug("worker pool stopped") def recycle(self): - for name, proc in self.workers.items(): - if proc.is_alive(): - logger.debug("shutting down worker for device %s", name) + logger.debug("recycling worker pool") + needs_restart = [] + + for device, proc in self.workers.items(): + jobs = self.total_jobs.get(device, 0) + if not proc.is_alive(): + logger.warning("worker for device %s has died", device) + needs_restart.append(device) + elif jobs > self.max_jobs_per_worker: + logger.info( + "shutting down worker for device %s after %s jobs", device, jobs + ) proc.join(self.join_timeout) - # proc.terminate() + if proc.is_alive(): + logger.warning( + "worker for device %s could not be recycled in time", device + ) + + self.workers[device] = None + del proc + needs_restart.append(device) else: - logger.warning("worker for device %s has died", name) + logger.debug( + "worker for device %s does not need to be recycled", device + ) - self.workers[name] = None - del proc - - logger.info("starting new workers") + logger.debug("starting new workers") for device in self.devices: - self.create_device_worker(device) + if device.device in needs_restart: + self.create_device_worker(device) + self.total_jobs[device.device] = 0 + + logger.debug("worker pool recycled") def submit( self, @@ -304,12 +324,6 @@ class DevicePoolExecutor: needs_device: Optional[DeviceParams] = None, **kwargs, ) -> None: - self.total_jobs += 1 - logger.debug("pool job count: %s", self.total_jobs) - if self.total_jobs > self.max_jobs_per_worker: - self.recycle() - self.total_jobs = 0 - device_idx = self.get_next_device(needs_device=needs_device) logger.info( "assigning job %s to device %s: %s", @@ -319,6 +333,15 @@ class DevicePoolExecutor: ) device = self.devices[device_idx].device + + if device in self.total_jobs: + self.total_jobs[device] += 1 + else: + self.total_jobs[device] = 1 + + logger.debug("device job count: %s", self.total_jobs[device]) + self.recycle() + self.pending[device].put((key, fn, args, kwargs), block=False) def status(self) -> List[Tuple[str, int, bool, bool]]: