From aefa5b4613c043811a00488bf731671dd2347609 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 19 Mar 2023 17:57:14 -0500 Subject: [PATCH] fix(api): clear job cancelled flag when worker starts a new job (#269) --- api/onnx_web/worker/context.py | 8 ++++++-- api/onnx_web/worker/worker.py | 22 ++++++++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index d4afe10c..a88d06e7 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -40,6 +40,10 @@ class WorkerContext: self.active_pid = active_pid self.last_progress = None + def start(self, job: str) -> None: + self.job = job + self.set_cancel(cancel=False) + def is_cancelled(self) -> bool: return self.cancel.value @@ -92,7 +96,7 @@ class WorkerContext: block=False, ) - def set_finished(self) -> None: + def finish(self) -> None: logger.debug("setting finished for job %s", self.job) self.last_progress = ProgressCommand( self.job, @@ -107,7 +111,7 @@ class WorkerContext: block=False, ) - def set_failed(self) -> None: + def fail(self) -> None: logger.warning("setting failure for job %s", self.job) try: self.last_progress = ProgressCommand( diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 6afd60b3..e7bb2671 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -22,7 +22,7 @@ def worker_main(context: WorkerContext, server: ServerContext): apply_patches(server) setproctitle("onnx-web worker: %s" % (context.device.device)) - logger.trace("checking in from worker, %s", get_available_providers()) + logger.trace("checking in from worker with providers: %s", get_available_providers()) # make leaking workers easier to recycle context.progress.cancel_join_thread() @@ -37,34 +37,40 @@ def worker_main(context: WorkerContext, server: ServerContext): ) exit(EXIT_REPLACED) + # wait briefly for the next job job = context.pending.get(timeout=1.0) - logger.info("worker for %s got job: %s", context.device.device, job.name) + logger.info("worker %s got job: %s", context.device.device, job.name) - context.job = job.name # TODO: hax + # clear flags and save the job name + context.start(job.name) logger.info("starting job: %s", job.name) + + # reset progress, which does a final check for cancellation context.set_progress(0) job.fn(context, *job.args, **job.kwargs) + + # confirm completion of the job logger.info("job succeeded: %s", job.name) - context.set_finished() + context.finish() except Empty: pass except KeyboardInterrupt: logger.info("worker got keyboard interrupt") - context.set_failed() + context.fail() exit(EXIT_INTERRUPT) except ValueError: logger.exception("value error in worker, exiting: %s") - context.set_failed() + context.fail() exit(EXIT_ERROR) except Exception as e: e_str = str(e) if "Failed to allocate memory" in e_str or "out of memory" in e_str: logger.error("detected out-of-memory error, exiting: %s", e) - context.set_failed() + context.fail() exit(EXIT_MEMORY) else: logger.exception( "error while running job", ) - context.set_failed() + context.fail() # carry on