diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 07cf082a..9bd5e819 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -2,6 +2,7 @@ from enum import IntEnum from logging import getLogger from typing import Any, Dict, List, Literal, Optional, Tuple, Union +import torch from onnxruntime import GraphOptimizationLevel, SessionOptions logger = getLogger(__name__) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index fa2e5f57..8f6a6d13 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -26,7 +26,8 @@ class DevicePoolExecutor: self.devices = devices self.finished = [] self.finished_limit = finished_limit - self.lock = Lock() + self.context = {} + self.locks = {} self.pending = {} self.progress = {} self.workers = {} @@ -42,15 +43,18 @@ class DevicePoolExecutor: # create a pending queue and progress value for each device for device in devices: name = device.device - cancel = Value("B", False, lock=self.lock) - progress = Value("I", 0, lock=self.lock) + lock = Lock() + self.locks[name] = lock + cancel = Value("B", False, lock=lock) + progress = Value("I", 0, lock=lock) + self.progress[name] = progress pending = Queue() - context = WorkerContext(name, cancel, device, pending, progress) self.pending[name] = pending - self.progress[name] = pending + context = WorkerContext(name, cancel, device, pending, progress) + self.context[name] = context logger.debug("starting worker for device %s", device) - self.workers[name] = Process(target=worker_init, args=(self.lock, context)) + self.workers[name] = Process(target=worker_init, args=(lock, context)) self.workers[name].start() def cancel(self, key: str) -> bool: @@ -135,6 +139,7 @@ class DevicePoolExecutor: ( device.device, self.pending[device.device].qsize(), + self.progress[device.device].value, self.workers[device.device].is_alive(), ) for device in self.devices diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index e5d46306..f5d3689c 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -1,7 +1,8 @@ from logging import getLogger +import torch # has to come before ORT from onnxruntime import get_available_providers from torch.multiprocessing import Lock, Queue -from traceback import print_exception +from traceback import format_exception from .context import WorkerContext @@ -30,5 +31,5 @@ def worker_init(lock: Lock, context: WorkerContext): fn(context, *args, **kwargs) logger.info("finished job") except Exception as e: - print_exception(type(e), e, e.__traceback__) + logger.error(format_exception(type(e), e, e.__traceback__))