From d1565b056e5699f29e0b7145f50c134d40258b1e Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Mar 2023 15:16:41 -0500 Subject: [PATCH] apply lint, make missing images an error --- api/onnx_web/server/api.py | 15 +++++-- api/onnx_web/worker/command.py | 80 +++++++++++++++++----------------- api/onnx_web/worker/context.py | 42 +++++++++++++++--- api/onnx_web/worker/pool.py | 12 +++-- api/onnx_web/worker/worker.py | 7 +-- 5 files changed, 99 insertions(+), 57 deletions(-) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 4bbec176..5783ddc8 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -50,7 +50,9 @@ from .utils import wrap_route logger = getLogger(__name__) -def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False): +def ready_reply( + ready: bool, progress: int = 0, error: bool = False, cancel: bool = False +): return jsonify( { "cancel": cancel, @@ -439,7 +441,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor): output_file = sanitize_name(output_file) cancel = pool.cancel(output_file) - return ready_reply(cancel == False, cancel=cancel) + return ready_reply(cancel is not False, cancel=cancel) def ready(context: ServerContext, pool: DevicePoolExecutor): @@ -454,8 +456,15 @@ def ready(context: ServerContext, pool: DevicePoolExecutor): output = base_join(context.output_path, output_file) if path.exists(output): return ready_reply(True) + else: + return ready_reply(True, error=True) # is a missing image really an error? yes will display the retry button - return ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel) + return ready_reply( + progress.finished, + progress=progress.progress, + error=progress.error, + cancel=progress.cancel, + ) def status(context: ServerContext, pool: DevicePoolExecutor): diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 5f3a0728..47eb75d3 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -1,43 +1,45 @@ -from typing import Callable, Any +from typing import Any, Callable -class ProgressCommand(): - device: str - job: str - finished: bool - progress: int - cancel: bool - error: bool - def __init__( - self, - job: str, - device: str, - finished: bool, - progress: int, - cancel: bool = False, - error: bool = False, - ): - self.job = job - self.device = device - self.finished = finished - self.progress = progress - self.cancel = cancel - self.error = error +class ProgressCommand: + device: str + job: str + finished: bool + progress: int + cancel: bool + error: bool -class JobCommand(): - name: str - fn: Callable[..., None] - args: Any - kwargs: dict[str, Any] + def __init__( + self, + job: str, + device: str, + finished: bool, + progress: int, + cancel: bool = False, + error: bool = False, + ): + self.job = job + self.device = device + self.finished = finished + self.progress = progress + self.cancel = cancel + self.error = error - def __init__( - self, - name: str, - fn: Callable[..., None], - args: Any, - kwargs: dict[str, Any], - ): - self.name = name - self.fn = fn - self.args = args - self.kwargs = kwargs + +class JobCommand: + name: str + fn: Callable[..., None] + args: Any + kwargs: dict[str, Any] + + def __init__( + self, + name: str, + fn: Callable[..., None], + args: Any, + kwargs: dict[str, Any], + ): + self.name = name + self.fn = fn + self.args = args + self.kwargs = kwargs diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 44c043a5..fc65c8a2 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -1,11 +1,11 @@ from logging import getLogger from os import getpid -from typing import Any, Callable, Tuple +from typing import Any, Callable from torch.multiprocessing import Queue, Value -from .command import JobCommand, ProgressCommand from ..params import DeviceParams +from .command import JobCommand, ProgressCommand logger = getLogger(__name__) @@ -73,17 +73,47 @@ class WorkerContext: raise RuntimeError("job has been cancelled") else: logger.debug("setting progress for job %s to %s", self.job, progress) - self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False) + self.progress.put( + ProgressCommand( + self.job, + self.device.device, + False, + progress, + self.is_cancelled(), + False, + ), + block=False, + ) def set_finished(self) -> None: logger.debug("setting finished for job %s", self.job) - self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False) + self.progress.put( + ProgressCommand( + self.job, + self.device.device, + True, + self.get_progress(), + self.is_cancelled(), + False, + ), + block=False, + ) def set_failed(self) -> None: logger.warning("setting failure for job %s", self.job) try: - self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), True), block=False) - except: + self.progress.put( + ProgressCommand( + self.job, + self.device.device, + True, + self.get_progress(), + self.is_cancelled(), + True, + ), + block=False, + ) + except Exception: logger.exception("error setting failure on job %s", self.job) diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index ca382881..c91a4f92 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -2,7 +2,7 @@ from collections import Counter from logging import getLogger from queue import Empty from threading import Thread -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from torch.multiprocessing import Process, Queue, Value @@ -29,7 +29,7 @@ class DevicePoolExecutor: threads: Dict[str, Thread] workers: Dict[str, Process] - active_jobs: Dict[str, ProgressCommand] # Device -> job progress + active_jobs: Dict[str, ProgressCommand] # Device -> job progress cancelled_jobs: List[str] finished_jobs: List[ProgressCommand] total_jobs: Dict[str, int] # Device -> job count @@ -149,11 +149,15 @@ class DevicePoolExecutor: del self.active_jobs[progress.job] self.join_leaking() else: - logger.debug("progress update for job: %s to %s", progress.job, progress.progress) + logger.debug( + "progress update for job: %s to %s", progress.job, progress.progress + ) self.active_jobs[progress.job] = progress if progress.job in self.cancelled_jobs: logger.debug( - "setting flag for cancelled job: %s on %s", progress.job, progress.device + "setting flag for cancelled job: %s on %s", + progress.job, + progress.device, ) self.context[progress.device].set_cancel() diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index 2d07289d..934d3bfa 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -2,7 +2,6 @@ from logging import getLogger from os import getpid from queue import Empty from sys import exit -from traceback import format_exception from setproctitle import setproctitle @@ -54,10 +53,8 @@ def worker_main(context: WorkerContext, server: ServerContext): logger.info("worker got keyboard interrupt") context.set_failed() exit(EXIT_INTERRUPT) - except ValueError as e: - logger.exception( - "value error in worker, exiting: %s" - ) + except ValueError: + logger.exception("value error in worker, exiting: %s") context.set_failed() exit(EXIT_ERROR) except Exception as e: