From a37d1a455015e6328d4718fa38082aa337cf34f7 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 20:37:22 -0600 Subject: [PATCH] use progress queue --- api/onnx_web/worker/context.py | 16 ++++-------- api/onnx_web/worker/pool.py | 45 +++++++++++++++++----------------- api/onnx_web/worker/worker.py | 4 +-- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 67bc4eb6..2daf23d0 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -22,11 +22,10 @@ class WorkerContext: key: str, device: DeviceParams, cancel: "Value[bool]" = None, - progress: "Value[int]" = None, - finished: "Queue[str]" = None, logs: "Queue[str]" = None, pending: "Queue[Any]" = None, - started: "Queue[Tuple[str, str]]" = None, + progress: "Queue[Tuple[str, int]]" = None, + finished: "Queue[str]" = None, ): self.key = key self.device = device @@ -35,7 +34,6 @@ class WorkerContext: self.finished = finished self.logs = logs self.pending = pending - self.started = started def is_cancelled(self) -> bool: return self.cancel.value @@ -65,14 +63,10 @@ class WorkerContext: self.cancel.value = cancel def set_progress(self, progress: int) -> None: - with self.progress.get_lock(): - self.progress.value = progress + self.progress.put((self.key, self.device.device, progress)) - def put_finished(self, job: str) -> None: - self.finished.put((job, self.device.device)) - - def put_started(self, job: str) -> None: - self.started.put((job, self.device.device)) + def set_finished(self) -> None: + self.finished.put((self.key, self.device.device)) def clear_flags(self) -> None: self.set_cancel(False) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 97cc683a..589b7938 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -18,7 +18,7 @@ class DevicePoolExecutor: devices: List[DeviceParams] = None pending: Dict[str, "Queue[WorkerContext]"] = None workers: Dict[str, Process] = None - active_jobs: Dict[str, str] = None + active_jobs: Dict[str, Tuple[str, int]] = None finished_jobs: List[Tuple[str, int, bool]] = None def __init__( @@ -40,7 +40,7 @@ class DevicePoolExecutor: self.finished_jobs = [] self.total_jobs = 0 # TODO: turn this into a Dict per-worker - self.started = Queue() + self.progress = Queue() self.finished = Queue() self.create_logger_worker() @@ -72,11 +72,10 @@ class DevicePoolExecutor: name, device, cancel=Value("B", False), - progress=Value("I", 0), + progress=self.progress, finished=self.finished, logs=self.log_queue, pending=pending, - started=self.started, ) self.context[name] = context self.workers[name] = Process(target=worker_init, args=(context, self.server)) @@ -85,30 +84,32 @@ class DevicePoolExecutor: self.workers[name].start() def create_queue_workers(self) -> None: - def started_worker(pending: Queue): - logger.info("checking in from started thread") + def progress_worker(progress: Queue): + logger.info("checking in from progress worker thread") while True: - job, device = pending.get() - logger.info("job has been started: %s", job) - self.active_jobs[device] = job + job, device, value = progress.get() + logger.info("progress update for job: %s, %s", job, value) + self.active_jobs[job] = (device, value) def finished_worker(finished: Queue): - logger.info("checking in from finished thread") + logger.info("checking in from finished worker thread") while True: job, device = finished.get() logger.info("job has been finished: %s", job) - context = self.get_job_context(job) + context = self.context[device] + _device, progress = self.active_jobs[job] self.finished_jobs.append( - (job, context.progress.value, context.cancel.value) + (job, progress, context.cancel.value) ) + del self.active_jobs[job] - self.started_thread = Thread(target=started_worker, args=(self.started,)) - self.started_thread.start() + self.progress_thread = Thread(target=progress_worker, args=(self.progress,)) + self.progress_thread.start() self.finished_thread = Thread(target=finished_worker, args=(self.finished,)) self.finished_thread.start() def get_job_context(self, key: str) -> WorkerContext: - device = self.active_jobs[key] + device, _progress = self.active_jobs[key] return self.context[device] def cancel(self, key: str) -> bool: @@ -141,8 +142,8 @@ class DevicePoolExecutor: return (None, 0) # TODO: prune here, maybe? - context = self.get_job_context(key) - return (False, context.progress.value) + _device, progress = self.active_jobs[key] + return (False, progress) def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible @@ -164,7 +165,7 @@ class DevicePoolExecutor: return lowest_devices[0] def join(self): - self.started_thread.join(self.join_timeout) + self.progress_thread.join(self.join_timeout) self.finished_thread.join(self.join_timeout) for device, worker in self.workers.items(): @@ -226,12 +227,12 @@ class DevicePoolExecutor: ( name, self.workers[name].is_alive(), - context.pending.qsize(), - context.cancel.value, + self.context[device].pending.qsize(), + self.context[device].cancel.value, False, - context.progress.value, + progress, ) - for name, context in self.context.items() + for name, device, progress in self.active_jobs ] pending.extend( [ diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index db23540f..cbd3afa7 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -37,9 +37,9 @@ def worker_init(context: WorkerContext, server: ServerContext): name = args[3][0] try: + context.key = name # TODO: hax context.clear_flags() logger.info("starting job: %s", name) - context.put_started(name) fn(context, *args, **kwargs) logger.info("job succeeded: %s", name) except Exception as e: @@ -48,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext): format_exception(type(e), e, e.__traceback__), ) finally: - context.put_finished(name) + context.set_finished() logger.info("finished job: %s", name)