1
0
Fork 0

pass pool to threads

This commit is contained in:
Sean Sube 2023-03-22 22:58:46 -05:00
parent 86c1b29c31
commit 6b4c046867
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 15 additions and 5 deletions

View File

@ -1,5 +1,4 @@
from collections import Counter from collections import Counter
from functools import partial
from logging import getLogger from logging import getLogger
from queue import Empty from queue import Empty
from threading import Thread from threading import Thread
@ -107,7 +106,7 @@ class DevicePoolExecutor:
self.context[name] = context self.context[name] = context
worker = Process( worker = Process(
name=f"onnx-web worker: {name}", name=f"onnx-web worker: {name}",
target=partial(worker_main, self), target=worker_main,
args=(context, self.server), args=(context, self.server),
) )
@ -118,7 +117,13 @@ class DevicePoolExecutor:
def create_logger_worker(self) -> None: def create_logger_worker(self) -> None:
logger_thread = Thread( logger_thread = Thread(
name="onnx-web logger", target=logger_main, args=(self.logs,), daemon=True name="onnx-web logger",
target=logger_main,
args=(
self,
self.logs,
),
daemon=True,
) )
self.threads["logger"] = logger_thread self.threads["logger"] = logger_thread
@ -128,8 +133,11 @@ class DevicePoolExecutor:
def create_progress_worker(self) -> None: def create_progress_worker(self) -> None:
progress_thread = Thread( progress_thread = Thread(
name="onnx-web progress", name="onnx-web progress",
target=partial(progress_main, self), target=progress_main,
args=(self.progress,), args=(
self,
self.progress,
),
daemon=True, daemon=True,
) )
self.threads["progress"] = progress_thread self.threads["progress"] = progress_thread
@ -417,6 +425,7 @@ class DevicePoolExecutor:
) )
self.context[progress.device].set_cancel() self.context[progress.device].set_cancel()
def logger_main(pool: DevicePoolExecutor, logs: Queue): def logger_main(pool: DevicePoolExecutor, logs: Queue):
logger.trace("checking in from logger worker thread") logger.trace("checking in from logger worker thread")
@ -433,6 +442,7 @@ def logger_main(pool: DevicePoolExecutor, logs: Queue):
except Exception: except Exception:
logger.exception("error in log worker") logger.exception("error in log worker")
def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"): def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread") logger.trace("checking in from progress worker thread")
while True: while True: