diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 44ecc824..c79a981f 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,5 +1,4 @@ from collections import Counter -from functools import partial from logging import getLogger from queue import Empty from threading import Thread @@ -107,7 +106,7 @@ class DevicePoolExecutor: self.context[name] = context worker = Process( name=f"onnx-web worker: {name}", - target=partial(worker_main, self), + target=worker_main, args=(context, self.server), ) @@ -118,7 +117,13 @@ class DevicePoolExecutor: def create_logger_worker(self) -> None: logger_thread = Thread( - name="onnx-web logger", target=logger_main, args=(self.logs,), daemon=True + name="onnx-web logger", + target=logger_main, + args=( + self, + self.logs, + ), + daemon=True, ) self.threads["logger"] = logger_thread @@ -128,8 +133,11 @@ class DevicePoolExecutor: def create_progress_worker(self) -> None: progress_thread = Thread( name="onnx-web progress", - target=partial(progress_main, self), - args=(self.progress,), + target=progress_main, + args=( + self, + self.progress, + ), daemon=True, ) self.threads["progress"] = progress_thread @@ -417,6 +425,7 @@ class DevicePoolExecutor: ) self.context[progress.device].set_cancel() + def logger_main(pool: DevicePoolExecutor, logs: Queue): logger.trace("checking in from logger worker thread") @@ -433,6 +442,7 @@ def logger_main(pool: DevicePoolExecutor, logs: Queue): except Exception: logger.exception("error in log worker") + def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"): logger.trace("checking in from progress worker thread") while True: