From 7cf5554beff501988c0e1c740170d4e57e0c8e5e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Mar 2023 15:12:09 -0500 Subject: [PATCH] feat(api): add error flag to image ready response --- api/onnx_web/server/api.py | 12 ++-- api/onnx_web/worker/command.py | 43 ++++++++++++++ api/onnx_web/worker/context.py | 35 +++++------ api/onnx_web/worker/pool.py | 102 ++++++++++++++++++--------------- api/onnx_web/worker/worker.py | 24 ++++---- gui/src/client/api.ts | 2 + 6 files changed, 141 insertions(+), 77 deletions(-) create mode 100644 api/onnx_web/worker/command.py diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 5cad5604..4bbec176 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -50,9 +50,11 @@ from .utils import wrap_route logger = getLogger(__name__) -def ready_reply(ready: bool, progress: int = 0): +def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False): return jsonify( { + "cancel": cancel, + "error": error, "progress": progress, "ready": ready, } @@ -437,7 +439,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor): output_file = sanitize_name(output_file) cancel = pool.cancel(output_file) - return ready_reply(cancel) + return ready_reply(cancel == False, cancel=cancel) def ready(context: ServerContext, pool: DevicePoolExecutor): @@ -446,14 +448,14 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): return error_reply("output name is required") output_file = sanitize_name(output_file) - done, progress = pool.done(output_file) + progress = pool.done(output_file) - if done is None: + if progress is None: output = base_join(context.output_path, output_file) if path.exists(output): return ready_reply(True) - return ready_reply(done or False, progress=progress) + return ready_reply(False) if progress is None else ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel) def status(context: ServerContext, pool: DevicePoolExecutor): diff --git 
a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py new file mode 100644 index 00000000..5f3a0728 --- /dev/null +++ b/api/onnx_web/worker/command.py @@ -0,0 +1,43 @@ +from typing import Callable, Any + +class ProgressCommand(): + device: str + job: str + finished: bool + progress: int + cancel: bool + error: bool + + def __init__( + self, + job: str, + device: str, + finished: bool, + progress: int, + cancel: bool = False, + error: bool = False, + ): + self.job = job + self.device = device + self.finished = finished + self.progress = progress + self.cancel = cancel + self.error = error + +class JobCommand(): + name: str + fn: Callable[..., None] + args: Any + kwargs: dict[str, Any] + + def __init__( + self, + name: str, + fn: Callable[..., None], + args: Any, + kwargs: dict[str, Any], + ): + self.name = name + self.fn = fn + self.args = args + self.kwargs = kwargs diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 5d2013a2..44c043a5 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Tuple from torch.multiprocessing import Queue, Value +from .command import JobCommand, ProgressCommand from ..params import DeviceParams logger = getLogger(__name__) @@ -15,9 +16,9 @@ ProgressCallback = Callable[[int, int, Any], None] class WorkerContext: cancel: "Value[bool]" job: str - pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]" + pending: "Queue[JobCommand]" current: "Value[int]" - progress: "Queue[Tuple[str, str, int]]" + progress: "Queue[ProgressCommand]" def __init__( self, @@ -25,16 +26,14 @@ class WorkerContext: device: DeviceParams, cancel: "Value[bool]", logs: "Queue[str]", - pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]", - progress: "Queue[Tuple[str, str, int]]", - finished: "Queue[Tuple[str, str]]", + pending: "Queue[JobCommand]", + progress: "Queue[ProgressCommand]", current: "Value[int]", ): self.job = job self.device = 
device self.cancel = cancel self.progress = progress - self.finished = finished self.logs = logs self.pending = pending self.current = current @@ -61,11 +60,7 @@ class WorkerContext: def get_progress_callback(self) -> ProgressCallback: def on_progress(step: int, timestep: int, latents: Any): on_progress.step = step - if self.is_cancelled(): - raise RuntimeError("job has been cancelled") - else: - logger.debug("setting progress for job %s to %s", self.job, step) - self.set_progress(step) + self.set_progress(step) return on_progress @@ -74,14 +69,22 @@ class WorkerContext: self.cancel.value = cancel def set_progress(self, progress: int) -> None: - self.progress.put((self.job, self.device.device, progress), block=False) + if self.is_cancelled(): + raise RuntimeError("job has been cancelled") + else: + logger.debug("setting progress for job %s to %s", self.job, progress) + self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False) def set_finished(self) -> None: - self.finished.put((self.job, self.device.device), block=False) + logger.debug("setting finished for job %s", self.job) + self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False) - def clear_flags(self) -> None: - self.set_cancel(False) - self.set_progress(0) + def set_failed(self) -> None: + logger.warning("setting failure for job %s", self.job) + try: + self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), True), block=False) + except Exception: + logger.exception("error setting failure on job %s", self.job) class JobStatus: diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index e222a6b4..ca382881 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -8,6 +8,7 @@ from torch.multiprocessing import Process, Queue, Value from ..params import DeviceParams from ..server import 
ServerContext +from .command import JobCommand, ProgressCommand from .context import WorkerContext from .worker import worker_main @@ -24,18 +25,17 @@ class DevicePoolExecutor: leaking: List[Tuple[str, Process]] context: Dict[str, WorkerContext] # Device -> Context current: Dict[str, "Value[int]"] - pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"] + pending: Dict[str, "Queue[JobCommand]"] threads: Dict[str, Thread] workers: Dict[str, Process] - active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus] + active_jobs: Dict[str, ProgressCommand] # Device -> job progress cancelled_jobs: List[str] - finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus] + finished_jobs: List[ProgressCommand] total_jobs: Dict[str, int] # Device -> job count - logs: "Queue" - progress: "Queue[Tuple[str, str, int]]" - finished: "Queue[Tuple[str, str]]" + logs: "Queue[str]" + progress: "Queue[ProgressCommand]" def __init__( self, @@ -142,18 +142,27 @@ class DevicePoolExecutor: logger_thread.start() def create_progress_worker(self) -> None: - def progress_worker(progress: Queue): + def update_job(progress: ProgressCommand): + if progress.finished: + logger.info("job has finished: %s", progress.job) + self.finished_jobs.append(progress) + del self.active_jobs[progress.job] + self.join_leaking() + else: + logger.debug("progress update for job: %s to %s", progress.job, progress.progress) + self.active_jobs[progress.job] = progress + if progress.job in self.cancelled_jobs: + logger.debug( + "setting flag for cancelled job: %s on %s", progress.job, progress.device + ) + self.context[progress.device].set_cancel() + + def progress_worker(queue: "Queue[ProgressCommand]"): logger.trace("checking in from progress worker thread") while True: try: - job, device, value = progress.get(timeout=(self.join_timeout / 2)) - logger.debug("progress update for job: %s to %s", job, value) - self.active_jobs[job] = (device, value) - if job in 
self.cancelled_jobs: - logger.debug( - "setting flag for cancelled job: %s on %s", job, device - ) - self.context[device].set_cancel() + progress = queue.get(timeout=(self.join_timeout / 2)) + update_job(progress) except Empty: pass except ValueError: @@ -178,12 +187,7 @@ class DevicePoolExecutor: while True: try: job, device = finished.get(timeout=(self.join_timeout / 2)) - logger.info("job has been finished: %s", job) - context = self.context[device] - _device, progress = self.active_jobs[job] - self.finished_jobs.append((job, progress, context.cancel.value)) - del self.active_jobs[job] - self.join_leaking() + except Empty: pass except ValueError: @@ -232,37 +236,36 @@ class DevicePoolExecutor: should be cancelled on the next progress callback. """ - self.cancelled_jobs.append(key) + for job in self.finished_jobs: + if job.job == key: + logger.debug("cannot cancel finished job: %s", key) + return False if key not in self.active_jobs: - logger.debug("cancelled job has not been started yet: %s", key) - return True - - device, _progress = self.active_jobs[key] - logger.info("cancelling job %s, active on device %s", key, device) - - context = self.context[device] - context.set_cancel() + logger.debug("cancelled job is not active: %s", key) + else: + job = self.active_jobs[key] + logger.info("cancelling job %s, active on device %s", key, job.device) + self.cancelled_jobs.append(key) return True - def done(self, key: str) -> Tuple[Optional[bool], int]: + def done(self, key: str) -> Optional[ProgressCommand]: """ Check if a job has been finished and report the last progress update. If the job is still active or pending, the first item will be False. If the job is not finished or active, the first item will be None. 
""" - for k, p, c in self.finished_jobs: - if k == key: - return (True, p) + for job in self.finished_jobs: + if job.job == key: + return job if key not in self.active_jobs: logger.debug("checking status for unknown job: %s", key) - return (None, 0) + return None - _device, progress = self.active_jobs[key] - return (False, progress) + return self.active_jobs[key] def join(self): logger.info("stopping worker pool") @@ -387,22 +390,29 @@ class DevicePoolExecutor: logger.debug("job count for device %s: %s", device, self.total_jobs[device]) self.recycle() - self.pending[device].put((key, fn, args, kwargs), block=False) + self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False) - def status(self) -> List[Tuple[str, int, bool, bool]]: + def status(self) -> List[Tuple[str, int, bool, bool, bool]]: history = [ - (name, progress, False, name in self.cancelled_jobs) - for name, (_device, progress) in self.active_jobs.items() + ( + name, + job.progress, + job.finished, + job.cancel, + job.error, + ) + for name, job in self.active_jobs.items() ] history.extend( [ ( - name, - progress, - True, - cancel, + job.job, + job.progress, + job.finished, + job.cancel, + job.error, ) - for name, progress, cancel in self.finished_jobs + for job in self.finished_jobs ] ) return history diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 740d234b..2d07289d 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -39,32 +39,36 @@ def worker_main(context: WorkerContext, server: ServerContext): ) exit(EXIT_REPLACED) - name, fn, args, kwargs = context.pending.get(timeout=1.0) - logger.info("worker for %s got job: %s", context.device.device, name) + job = context.pending.get(timeout=1.0) + logger.info("worker for %s got job: %s", context.device.device, job.name) - context.job = name # TODO: hax - context.clear_flags() - logger.info("starting job: %s", name) - fn(context, *args, **kwargs) - logger.info("job succeeded: %s", 
name) + context.job = job.name # TODO: hax + logger.info("starting job: %s", job.name) + context.set_progress(0) + job.fn(context, *job.args, **job.kwargs) + logger.info("job succeeded: %s", job.name) + context.set_finished() except Empty: pass except KeyboardInterrupt: logger.info("worker got keyboard interrupt") + context.set_failed() exit(EXIT_INTERRUPT) except ValueError as e: - logger.info( - "value error in worker, exiting: %s", - format_exception(type(e), e, e.__traceback__), + logger.exception( + "value error in worker, exiting" ) + context.set_failed() exit(EXIT_ERROR) except Exception as e: e_str = str(e) if "Failed to allocate memory" in e_str or "out of memory" in e_str: logger.error("detected out-of-memory error, exiting: %s", e) + context.set_failed() exit(EXIT_MEMORY) else: logger.exception( "error while running job", ) + context.set_failed() + # carry on diff --git a/gui/src/client/api.ts b/gui/src/client/api.ts index 606966a1..44414b1c 100644 --- a/gui/src/client/api.ts +++ b/gui/src/client/api.ts @@ -175,6 +175,8 @@ export interface ImageResponse { * Status response from the ready endpoint. */ export interface ReadyResponse { + cancel: boolean; + error: boolean; progress: number; ready: boolean; }