1
0
Fork 0

lint(api): extract worker thread main functions (#279)

This commit is contained in:
Sean Sube 2023-03-22 22:55:34 -05:00
parent 4dd68ea6b6
commit 86c1b29c31
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 33 additions and 32 deletions

View File

@ -1,4 +1,5 @@
from collections import Counter from collections import Counter
from functools import partial
from logging import getLogger from logging import getLogger
from queue import Empty from queue import Empty
from threading import Thread from threading import Thread
@ -106,7 +107,7 @@ class DevicePoolExecutor:
self.context[name] = context self.context[name] = context
worker = Process( worker = Process(
name=f"onnx-web worker: {name}", name=f"onnx-web worker: {name}",
target=worker_main, target=partial(worker_main, self),
args=(context, self.server), args=(context, self.server),
) )
@ -116,24 +117,8 @@ class DevicePoolExecutor:
current.value = worker.pid current.value = worker.pid
def create_logger_worker(self) -> None: def create_logger_worker(self) -> None:
def logger_worker(logs: Queue):
logger.trace("checking in from logger worker thread")
while True:
try:
job = logs.get(timeout=(self.join_timeout / 2))
with open("worker.log", "w") as f:
logger.info("got log: %s", job)
f.write(str(job) + "\n\n")
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in log worker")
logger_thread = Thread( logger_thread = Thread(
name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True name="onnx-web logger", target=logger_main, args=(self.logs,), daemon=True
) )
self.threads["logger"] = logger_thread self.threads["logger"] = logger_thread
@ -141,22 +126,9 @@ class DevicePoolExecutor:
logger_thread.start() logger_thread.start()
def create_progress_worker(self) -> None: def create_progress_worker(self) -> None:
def progress_worker(queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
progress = queue.get(timeout=(self.join_timeout / 2))
self.update_job(progress)
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in progress worker")
progress_thread = Thread( progress_thread = Thread(
name="onnx-web progress", name="onnx-web progress",
target=progress_worker, target=partial(progress_main, self),
args=(self.progress,), args=(self.progress,),
daemon=True, daemon=True,
) )
@ -444,3 +416,32 @@ class DevicePoolExecutor:
progress.device, progress.device,
) )
self.context[progress.device].set_cancel() self.context[progress.device].set_cancel()
def logger_main(pool: DevicePoolExecutor, logs: Queue):
logger.trace("checking in from logger worker thread")
while True:
try:
job = logs.get(timeout=(pool.join_timeout / 2))
with open("worker.log", "w") as f:
logger.info("got log: %s", job)
f.write(str(job) + "\n\n")
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in log worker")
def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
progress = queue.get(timeout=(pool.join_timeout / 2))
pool.update_job(progress)
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in progress worker")