From 2327b2402217208f5be6c3ee109bbb3c5a00ee2f Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 17:35:31 -0600 Subject: [PATCH] join all threads --- api/onnx_web/worker/context.py | 12 ++++++------ api/onnx_web/worker/pool.py | 8 +++----- api/onnx_web/worker/worker.py | 2 +- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 1ef564a7..a69e28a2 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -13,13 +13,13 @@ ProgressCallback = Callable[[int, int, Any], None] class WorkerContext: cancel: "Value[bool]" = None - key: str = None + job: str = None pending: "Queue[Tuple[Callable, Any, Any]]" = None progress: "Value[int]" = None def __init__( self, - key: str, + job: str, device: DeviceParams, cancel: "Value[bool]" = None, logs: "Queue[str]" = None, @@ -27,7 +27,7 @@ class WorkerContext: progress: "Queue[Tuple[str, int]]" = None, finished: "Queue[str]" = None, ): - self.key = key + self.job = job self.device = device self.cancel = cancel self.progress = progress @@ -53,7 +53,7 @@ class WorkerContext: if self.is_cancelled(): raise RuntimeError("job has been cancelled") else: - logger.debug("setting progress for job %s to %s", self.key, step) + logger.debug("setting progress for job %s to %s", self.job, step) self.set_progress(step) return on_progress @@ -63,10 +63,10 @@ class WorkerContext: self.cancel.value = cancel def set_progress(self, progress: int) -> None: - self.progress.put((self.key, self.device.device, progress)) + self.progress.put((self.job, self.device.device, progress)) def set_finished(self) -> None: - self.finished.put((self.key, self.device.device)) + self.finished.put((self.job, self.device.device)) def clear_flags(self) -> None: self.set_cancel(False) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index dced875d..522533d7 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -200,16 +200,14 @@ class DevicePoolExecutor: return (False, progress) def join(self): - self.progress_thread.join(self.join_timeout) - self.finished_thread.join(self.join_timeout) - for device, worker in self.workers.items(): if worker.is_alive(): logger.info("stopping worker for device %s", device) worker.join(self.join_timeout) - if self.logger.is_alive(): - self.logger.join(self.join_timeout) + for name, thread in self.threads.items(): + logger.info("stopping worker thread: %s", name) + thread.join(self.join_timeout) def recycle(self): for name, proc in self.workers.items(): diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 9518f948..d6e2d0c9 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -25,7 +25,7 @@ def worker_main(context: WorkerContext, server: ServerContext): name = args[3][0] try: - context.key = name # TODO: hax + context.job = name # TODO: hax context.clear_flags() logger.info("starting job: %s", name) fn(context, *args, **kwargs)