diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index bb19466c..86a5896d 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -39,7 +39,7 @@ def main(): CORS(app, origins=context.cors_origin) # any is a fake device, should not be in the pool - pool = DevicePoolExecutor([p for p in get_available_platforms() if p.device != "any"]) + pool = DevicePoolExecutor(context, [p for p in get_available_platforms() if p.device != "any"]) # register routes register_static_routes(app, context, pool) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 8f6a6d13..7dad4bee 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -5,6 +5,7 @@ from torch.multiprocessing import Lock, Process, Value from typing import Callable, Dict, List, Optional, Tuple from ..params import DeviceParams +from ..server import ServerContext from .context import WorkerContext from .worker import logger_init, worker_init @@ -20,9 +21,11 @@ class DevicePoolExecutor: def __init__( self, + server: ServerContext, devices: List[DeviceParams], finished_limit: int = 10, ): + self.server = server self.devices = devices self.finished = [] self.finished_limit = finished_limit @@ -32,9 +35,12 @@ class DevicePoolExecutor: self.progress = {} self.workers = {} + # TODO: make this a method logger.debug("starting log worker") self.log_queue = Queue() - self.logger = Process(target=logger_init, args=(self.lock, self.log_queue)) + log_lock = Lock() + self.locks["logger"] = log_lock + self.logger = Process(target=logger_init, args=(log_lock, self.log_queue)) self.logger.start() logger.debug("testing log worker") @@ -43,10 +49,11 @@ class DevicePoolExecutor: # create a pending queue and progress value for each device for device in devices: name = device.device + # TODO: make this a method lock = Lock() self.locks[name] = lock cancel = Value("B", False, lock=lock) - progress = Value("I", 0, lock=lock) + progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why? self.progress[name] = progress pending = Queue() self.pending[name] = pending @@ -54,7 +61,7 @@ class DevicePoolExecutor: self.context[name] = context logger.debug("starting worker for device %s", device) - self.workers[name] = Process(target=worker_init, args=(lock, context)) + self.workers[name] = Process(target=worker_init, args=(lock, context, server)) self.workers[name].start() def cancel(self, key: str) -> bool: diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index f5d3689c..3497da70 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -5,9 +5,11 @@ from torch.multiprocessing import Lock, Queue from traceback import format_exception from .context import WorkerContext +from ..server import ServerContext, apply_patches logger = getLogger(__name__) + def logger_init(lock: Lock, logs: Queue): with lock: logger.info("checking in from logger, %s", lock) @@ -19,10 +21,12 @@ def logger_init(lock: Lock, logs: Queue): f.write(str(job) + "\n\n") -def worker_init(lock: Lock, context: WorkerContext): +def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): with lock: logger.info("checking in from worker, %s, %s", lock, get_available_providers()) + apply_patches(server) + while True: job = context.pending.get() logger.info("got job: %s", job)