1
0
Fork 0

lint(api): extract worker thread main functions (#279)

This commit is contained in:
Sean Sube 2023-03-22 22:55:34 -05:00
parent 4dd68ea6b6
commit 86c1b29c31
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 33 additions and 32 deletions

View File

@ -1,4 +1,5 @@
from collections import Counter
from functools import partial
from logging import getLogger
from queue import Empty
from threading import Thread
@ -106,7 +107,7 @@ class DevicePoolExecutor:
self.context[name] = context
worker = Process(
name=f"onnx-web worker: {name}",
target=worker_main,
target=partial(worker_main, self),
args=(context, self.server),
)
@ -116,24 +117,8 @@ class DevicePoolExecutor:
current.value = worker.pid
def create_logger_worker(self) -> None:
def logger_worker(logs: Queue):
logger.trace("checking in from logger worker thread")
while True:
try:
job = logs.get(timeout=(self.join_timeout / 2))
with open("worker.log", "w") as f:
logger.info("got log: %s", job)
f.write(str(job) + "\n\n")
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in log worker")
logger_thread = Thread(
name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True
name="onnx-web logger", target=logger_main, args=(self.logs,), daemon=True
)
self.threads["logger"] = logger_thread
@ -141,22 +126,9 @@ class DevicePoolExecutor:
logger_thread.start()
def create_progress_worker(self) -> None:
def progress_worker(queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
progress = queue.get(timeout=(self.join_timeout / 2))
self.update_job(progress)
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in progress worker")
progress_thread = Thread(
name="onnx-web progress",
target=progress_worker,
target=partial(progress_main, self),
args=(self.progress,),
daemon=True,
)
@ -444,3 +416,32 @@ class DevicePoolExecutor:
progress.device,
)
self.context[progress.device].set_cancel()
def logger_main(pool: DevicePoolExecutor, logs: Queue):
logger.trace("checking in from logger worker thread")
while True:
try:
job = logs.get(timeout=(pool.join_timeout / 2))
with open("worker.log", "w") as f:
logger.info("got log: %s", job)
f.write(str(job) + "\n\n")
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in log worker")
def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
progress = queue.get(timeout=(pool.join_timeout / 2))
pool.update_job(progress)
except Empty:
pass
except ValueError:
break
except Exception:
logger.exception("error in progress worker")