diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 8dfb7715..59f55fdd 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -22,12 +22,16 @@ class WorkerContext: device: DeviceParams, pending: "Queue[Any]", progress: "Value[int]", + logs: "Queue[str]", + finished: "Value[bool]", ): self.key = key self.cancel = cancel self.device = device self.pending = pending self.progress = progress + self.logs = logs + self.finished = finished def is_cancelled(self) -> bool: return self.cancel.value diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 7dad4bee..0721896e 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -14,10 +14,11 @@ logger = getLogger(__name__) class DevicePoolExecutor: devices: List[DeviceParams] = None - finished: List[Tuple[str, int]] = None + finished: Dict[str, "Value[bool]"] = None pending: Dict[str, "Queue[WorkerContext]"] = None - progress: Dict[str, Value] = None + progress: Dict[str, "Value[int]"] = None workers: Dict[str, Process] = None + jobs: Dict[str, str] = None def __init__( self, @@ -27,13 +28,14 @@ class DevicePoolExecutor: ): self.server = server self.devices = devices - self.finished = [] + self.finished = {} self.finished_limit = finished_limit self.context = {} self.locks = {} self.pending = {} self.progress = {} self.workers = {} + self.jobs = {} # Dict[Output, Device] # TODO: make this a method logger.debug("starting log worker") @@ -53,11 +55,13 @@ class DevicePoolExecutor: lock = Lock() self.locks[name] = lock cancel = Value("B", False, lock=lock) + finished = Value("B", False) + self.finished[name] = finished progress = Value("I", 0) # , lock=lock) # needs its own lock for some reason. TODO: why? self.progress[name] = progress pending = Queue() self.pending[name] = pending - context = WorkerContext(name, cancel, device, pending, progress) + context = WorkerContext(name, cancel, device, pending, progress, self.log_queue, finished) self.context[name] = context logger.debug("starting worker for device %s", device) @@ -73,12 +77,16 @@ class DevicePoolExecutor: raise NotImplementedError() def done(self, key: str) -> Tuple[Optional[bool], int]: - for k, progress in self.finished: - if key == k: - return (True, progress) + if not key in self.jobs: + logger.warn("checking status for unknown key: %s", key) + return (None, 0) + + device = self.jobs[key] + finished = self.finished[device] + progress = self.progress[device] + + return (finished.value, progress.value) - logger.warn("checking status for unknown key: %s", key) - return (None, 0) def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible @@ -140,6 +148,8 @@ class DevicePoolExecutor: queue = self.pending[device.device] queue.put((fn, args, kwargs)) + self.jobs[key] = device.device + def status(self) -> List[Tuple[str, int, bool, int]]: pending = [ diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 3497da70..4edfb4bb 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -32,8 +32,20 @@ def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): logger.info("got job: %s", job) try: fn, args, kwargs = job + name = args[3][0] + logger.info("starting job: %s", name) + with context.finished.get_lock(): + context.finished.value = False + + with context.progress.get_lock(): + context.progress.value = 0 + fn(context, *args, **kwargs) - logger.info("finished job") + logger.info("finished job: %s", name) + + with context.finished.get_lock(): + context.finished.value = True + except Exception as e: logger.error(format_exception(type(e), e, e.__traceback__))