From 14ade839377d64a874f1bc9d6ba29e0bb00a8d33 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 26 Mar 2023 11:41:17 -0500
Subject: [PATCH] fix(api): enqueue next job when previous one finishes and after recycling worker

---
 api/onnx_web/worker/pool.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py
index 1b9e0fca..eae785a1 100644
--- a/api/onnx_web/worker/pool.py
+++ b/api/onnx_web/worker/pool.py
@@ -333,6 +333,7 @@ class DevicePoolExecutor:
         for device in self.devices:
             if device.device in needs_restart:
                 self.create_device_worker(device)
+                self.next_job(device)
 
         if self.logger_worker.is_alive():
             logger.debug("logger worker is running")
@@ -369,7 +370,6 @@ class DevicePoolExecutor:
         # build and queue job
         job = JobCommand(key, device, fn, args, kwargs)
         self.pending_jobs.append(job)
-        self.pending[device].put(job, block=False)
 
     def status(self) -> List[Tuple[str, int, bool, bool, bool, bool]]:
         history = [
@@ -411,6 +411,16 @@ class DevicePoolExecutor:
         )
         return history
 
+    def next_job(self, device: str):
+        for job in self.pending_jobs:
+            if job.device == device:
+                logger.debug("enqueuing job %s on device %s", job.name, device)
+                self.pending[device].put(job, block=False)
+                self.pending_jobs.remove(job)
+                return
+
+        logger.trace("no pending jobs for device %s", device)
+
     def update_job(self, progress: ProgressCommand):
         if progress.finished:
             # move from running to finished
@@ -422,6 +432,9 @@ class DevicePoolExecutor:
             self.join_leaking()
             if progress.job in self.cancelled_jobs:
                 self.cancelled_jobs.remove(progress.job)
+
+            # enqueue the next job for this device
+            self.next_job(progress.device)
         else:
             # move from pending to running
             logger.debug(