1
0
Fork 0

fix(api): clear job cancelled flag when worker starts a new job (#269)

This commit is contained in:
Sean Sube 2023-03-19 17:57:14 -05:00
parent ba0767179c
commit aefa5b4613
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 20 additions and 10 deletions

View File

@ -40,6 +40,10 @@ class WorkerContext:
self.active_pid = active_pid self.active_pid = active_pid
self.last_progress = None self.last_progress = None
def start(self, job: str) -> None:
self.job = job
self.set_cancel(cancel=False)
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
return self.cancel.value return self.cancel.value
@ -92,7 +96,7 @@ class WorkerContext:
block=False, block=False,
) )
def set_finished(self) -> None: def finish(self) -> None:
logger.debug("setting finished for job %s", self.job) logger.debug("setting finished for job %s", self.job)
self.last_progress = ProgressCommand( self.last_progress = ProgressCommand(
self.job, self.job,
@ -107,7 +111,7 @@ class WorkerContext:
block=False, block=False,
) )
def set_failed(self) -> None: def fail(self) -> None:
logger.warning("setting failure for job %s", self.job) logger.warning("setting failure for job %s", self.job)
try: try:
self.last_progress = ProgressCommand( self.last_progress = ProgressCommand(

View File

@ -22,7 +22,7 @@ def worker_main(context: WorkerContext, server: ServerContext):
apply_patches(server) apply_patches(server)
setproctitle("onnx-web worker: %s" % (context.device.device)) setproctitle("onnx-web worker: %s" % (context.device.device))
logger.trace("checking in from worker, %s", get_available_providers()) logger.trace("checking in from worker with providers: %s", get_available_providers())
# make leaking workers easier to recycle # make leaking workers easier to recycle
context.progress.cancel_join_thread() context.progress.cancel_join_thread()
@ -37,34 +37,40 @@ def worker_main(context: WorkerContext, server: ServerContext):
) )
exit(EXIT_REPLACED) exit(EXIT_REPLACED)
# wait briefly for the next job
job = context.pending.get(timeout=1.0) job = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, job.name) logger.info("worker %s got job: %s", context.device.device, job.name)
context.job = job.name # TODO: hax # clear flags and save the job name
context.start(job.name)
logger.info("starting job: %s", job.name) logger.info("starting job: %s", job.name)
# reset progress, which does a final check for cancellation
context.set_progress(0) context.set_progress(0)
job.fn(context, *job.args, **job.kwargs) job.fn(context, *job.args, **job.kwargs)
# confirm completion of the job
logger.info("job succeeded: %s", job.name) logger.info("job succeeded: %s", job.name)
context.set_finished() context.finish()
except Empty: except Empty:
pass pass
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("worker got keyboard interrupt") logger.info("worker got keyboard interrupt")
context.set_failed() context.fail()
exit(EXIT_INTERRUPT) exit(EXIT_INTERRUPT)
except ValueError: except ValueError:
logger.exception("value error in worker, exiting: %s") logger.exception("value error in worker, exiting: %s")
context.set_failed() context.fail()
exit(EXIT_ERROR) exit(EXIT_ERROR)
except Exception as e: except Exception as e:
e_str = str(e) e_str = str(e)
if "Failed to allocate memory" in e_str or "out of memory" in e_str: if "Failed to allocate memory" in e_str or "out of memory" in e_str:
logger.error("detected out-of-memory error, exiting: %s", e) logger.error("detected out-of-memory error, exiting: %s", e)
context.set_failed() context.fail()
exit(EXIT_MEMORY) exit(EXIT_MEMORY)
else: else:
logger.exception( logger.exception(
"error while running job", "error while running job",
) )
context.set_failed() context.fail()
# carry on # carry on