diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index b7eaf3a3..fa2e5f57 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -31,13 +31,14 @@ class DevicePoolExecutor: self.progress = {} self.workers = {} - log_queue = Queue() - logger_context = WorkerContext("logger", None, None, log_queue, None) - logger.debug("starting log worker") - self.logger = Process(target=logger_init, args=(self.lock, logger_context)) + self.log_queue = Queue() + self.logger = Process(target=logger_init, args=(self.lock, self.log_queue)) self.logger.start() + logger.debug("testing log worker") + self.log_queue.put("testing") + # create a pending queue and progress value for each device for device in devices: name = device.device @@ -52,9 +53,6 @@ class DevicePoolExecutor: self.workers[name] = Process(target=worker_init, args=(self.lock, context)) self.workers[name].start() - logger.debug("testing log worker") - log_queue.put("testing") - def cancel(self, key: str) -> bool: """ Cancel a job. If the job has not been started, this will cancel @@ -99,6 +97,9 @@ class DevicePoolExecutor: logger.info("stopping worker for device %s", device) worker.join(5) + if self.logger.is_alive(): + self.logger.join(5) + def prune(self): finished_count = len(self.finished) if finished_count > self.finished_limit: diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index cf47f85e..e5d46306 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,40 +1,34 @@ from logging import getLogger -from torch.multiprocessing import Lock -from time import sleep +from onnxruntime import get_available_providers +from torch.multiprocessing import Lock, Queue from traceback import print_exception from .context import WorkerContext logger = getLogger(__name__) -def logger_init(lock: Lock, context: WorkerContext): - logger.info("checking in from logger") +def logger_init(lock: Lock, logs: Queue): + with lock: + logger.info("checking in from logger, %s", lock) - with open("worker.log", "w") as f: - while True: - if context.pending.empty(): - logger.info("no logs, sleeping") - sleep(5) - else: - job = context.pending.get() - logger.info("got log: %s", job) - f.write(str(job) + "\n\n") + while True: + job = logs.get() + with open("worker.log", "w") as f: + logger.info("got log: %s", job) + f.write(str(job) + "\n\n") def worker_init(lock: Lock, context: WorkerContext): - logger.info("checking in from worker") + with lock: + logger.info("checking in from worker, %s, %s", lock, get_available_providers()) while True: - if context.pending.empty(): - logger.info("no jobs, sleeping") - sleep(5) - else: - job = context.pending.get() - logger.info("got job: %s", job) - try: - fn, args, kwargs = job - fn(context, *args, **kwargs) - logger.info("finished job") - except Exception as e: - print_exception(type(e), e, e.__traceback__) + job = context.pending.get() + logger.info("got job: %s", job) + try: + fn, args, kwargs = job + fn(context, *args, **kwargs) + logger.info("finished job") + except Exception as e: + print_exception(type(e), e, e.__traceback__)