1
0
Fork 0

fix(api): wait for worker to become idle before enqueueing next job (#286)

This commit is contained in:
Sean Sube 2023-04-15 20:37:53 -05:00
parent 17e7b6aff2
commit cfdd926fff
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 18 additions and 5 deletions

View File

@@ -20,6 +20,7 @@ class WorkerContext:
active_pid: "Value[int]"
progress: "Queue[ProgressCommand]"
last_progress: Optional[ProgressCommand]
idle: "Value[bool]"
def __init__(
self,
@@ -30,6 +31,7 @@ class WorkerContext:
pending: "Queue[JobCommand]",
progress: "Queue[ProgressCommand]",
active_pid: "Value[int]",
idle: "Value[bool]",
):
self.job = job
self.device = device
@@ -39,16 +41,21 @@ class WorkerContext:
self.pending = pending
self.active_pid = active_pid
self.last_progress = None
self.idle = idle
def start(self, job: str) -> None:
    """Mark the beginning of a job on this worker.

    Stores ``job`` as the current job name, then resets the cancel flag
    and the idle flag to ``False`` through their setters.

    Args:
        job: Name of the job that is starting.
    """
    # Remember which job this context is now running.
    self.job = job

    # A freshly started job is neither cancelled nor idle.
    self.set_cancel(cancel=False)
    self.set_idle(idle=False)
def is_active(self) -> bool:
    """Return whether the recorded active PID is this process's PID."""
    active = self.get_active()
    return active == getpid()
def is_cancelled(self) -> bool:
    """Return the current value stored in the ``cancel`` flag."""
    cancel_flag = self.cancel
    return cancel_flag.value
def is_active(self) -> bool:
    """Return whether the recorded active PID is this process's PID."""
    active = self.get_active()
    return active == getpid()
def is_idle(self) -> bool:
    """Return the current value stored in the ``idle`` flag."""
    idle_flag = self.idle
    return idle_flag.value
def get_active(self) -> int:
with self.active_pid.get_lock():
@@ -77,6 +84,10 @@ class WorkerContext:
with self.cancel.get_lock():
self.cancel.value = cancel
def set_idle(self, idle: bool = True) -> None:
    """Store ``idle`` in the ``idle`` flag while holding the flag's lock.

    Args:
        idle: New value for the flag; defaults to ``True``.
    """
    flag = self.idle
    with flag.get_lock():
        flag.value = idle
def set_progress(self, progress: int) -> None:
if self.is_cancelled():
raise RuntimeError("job has been cancelled")

View File

@@ -114,6 +114,7 @@ class DevicePoolExecutor:
logs=self.logs,
pending=self.pending[name],
active_pid=current,
idle=Value("B", False),
)
self.context[name] = context
@@ -567,7 +568,7 @@ def progress_main(pool: DevicePoolExecutor):
except Exception:
logger.exception("error in progress worker for device %s", device)
for device, queue in pool.pending.items():
if queue.empty():
for device, context in pool.context.items():
if context.is_idle():
logger.trace("enqueueing next job for idle worker")
pool.next_job(device)

View File

@@ -63,7 +63,8 @@ def worker_main(worker: WorkerContext, server: ServerContext):
logger.info("job succeeded: %s", job.name)
worker.finish()
except Empty:
pass
logger.trace("worker reached end of queue, setting idle flag")
worker.set_idle()
except KeyboardInterrupt:
logger.info("worker got keyboard interrupt")
worker.fail()