diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index ae386893..44ecc824 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -1,4 +1,5 @@ from collections import Counter +from functools import partial from logging import getLogger from queue import Empty from threading import Thread @@ -106,7 +107,7 @@ class DevicePoolExecutor: self.context[name] = context worker = Process( name=f"onnx-web worker: {name}", - target=worker_main, + target=partial(worker_main, self), args=(context, self.server), ) @@ -116,24 +117,8 @@ class DevicePoolExecutor: current.value = worker.pid def create_logger_worker(self) -> None: - def logger_worker(logs: Queue): - logger.trace("checking in from logger worker thread") - - while True: - try: - job = logs.get(timeout=(self.join_timeout / 2)) - with open("worker.log", "w") as f: - logger.info("got log: %s", job) - f.write(str(job) + "\n\n") - except Empty: - pass - except ValueError: - break - except Exception: - logger.exception("error in log worker") - logger_thread = Thread( - name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True + name="onnx-web logger", target=logger_main, args=(self.logs,), daemon=True ) self.threads["logger"] = logger_thread @@ -141,22 +126,9 @@ class DevicePoolExecutor: logger_thread.start() def create_progress_worker(self) -> None: - def progress_worker(queue: "Queue[ProgressCommand]"): - logger.trace("checking in from progress worker thread") - while True: - try: - progress = queue.get(timeout=(self.join_timeout / 2)) - self.update_job(progress) - except Empty: - pass - except ValueError: - break - except Exception: - logger.exception("error in progress worker") - progress_thread = Thread( name="onnx-web progress", - target=progress_worker, + target=partial(progress_main, self), args=(self.progress,), daemon=True, ) @@ -444,3 +416,32 @@ class DevicePoolExecutor: progress.device, ) self.context[progress.device].set_cancel() + +def logger_main(pool: DevicePoolExecutor, logs: Queue): + logger.trace("checking in from logger worker thread") + + while True: + try: + job = logs.get(timeout=(pool.join_timeout / 2)) + with open("worker.log", "w") as f: + logger.info("got log: %s", job) + f.write(str(job) + "\n\n") + except Empty: + pass + except ValueError: + break + except Exception: + logger.exception("error in log worker") + +def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"): + logger.trace("checking in from progress worker thread") + while True: + try: + progress = queue.get(timeout=(pool.join_timeout / 2)) + pool.update_job(progress) + except Empty: + pass + except ValueError: + break + except Exception: + logger.exception("error in progress worker")