apply lint, make missing images an error
This commit is contained in:
parent
7cf5554bef
commit
d1565b056e
|
@ -50,7 +50,9 @@ from .utils import wrap_route
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False):
|
||||
def ready_reply(
|
||||
ready: bool, progress: int = 0, error: bool = False, cancel: bool = False
|
||||
):
|
||||
return jsonify(
|
||||
{
|
||||
"cancel": cancel,
|
||||
|
@ -439,7 +441,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
|||
output_file = sanitize_name(output_file)
|
||||
cancel = pool.cancel(output_file)
|
||||
|
||||
return ready_reply(cancel == False, cancel=cancel)
|
||||
return ready_reply(cancel is not False, cancel=cancel)
|
||||
|
||||
|
||||
def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -454,8 +456,15 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
|||
output = base_join(context.output_path, output_file)
|
||||
if path.exists(output):
|
||||
return ready_reply(True)
|
||||
else:
|
||||
return ready_reply(True, error=True) # is a missing image really an error? yes will display the retry button
|
||||
|
||||
return ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel)
|
||||
return ready_reply(
|
||||
progress.finished,
|
||||
progress=progress.progress,
|
||||
error=progress.error,
|
||||
cancel=progress.cancel,
|
||||
)
|
||||
|
||||
|
||||
def status(context: ServerContext, pool: DevicePoolExecutor):
|
||||
|
|
|
@ -1,43 +1,45 @@
|
|||
from typing import Callable, Any
|
||||
from typing import Any, Callable
|
||||
|
||||
class ProgressCommand():
|
||||
device: str
|
||||
job: str
|
||||
finished: bool
|
||||
progress: int
|
||||
cancel: bool
|
||||
error: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job: str,
|
||||
device: str,
|
||||
finished: bool,
|
||||
progress: int,
|
||||
cancel: bool = False,
|
||||
error: bool = False,
|
||||
):
|
||||
self.job = job
|
||||
self.device = device
|
||||
self.finished = finished
|
||||
self.progress = progress
|
||||
self.cancel = cancel
|
||||
self.error = error
|
||||
class ProgressCommand:
|
||||
device: str
|
||||
job: str
|
||||
finished: bool
|
||||
progress: int
|
||||
cancel: bool
|
||||
error: bool
|
||||
|
||||
class JobCommand():
|
||||
name: str
|
||||
fn: Callable[..., None]
|
||||
args: Any
|
||||
kwargs: dict[str, Any]
|
||||
def __init__(
|
||||
self,
|
||||
job: str,
|
||||
device: str,
|
||||
finished: bool,
|
||||
progress: int,
|
||||
cancel: bool = False,
|
||||
error: bool = False,
|
||||
):
|
||||
self.job = job
|
||||
self.device = device
|
||||
self.finished = finished
|
||||
self.progress = progress
|
||||
self.cancel = cancel
|
||||
self.error = error
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
fn: Callable[..., None],
|
||||
args: Any,
|
||||
kwargs: dict[str, Any],
|
||||
):
|
||||
self.name = name
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
class JobCommand:
|
||||
name: str
|
||||
fn: Callable[..., None]
|
||||
args: Any
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
fn: Callable[..., None],
|
||||
args: Any,
|
||||
kwargs: dict[str, Any],
|
||||
):
|
||||
self.name = name
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
from logging import getLogger
|
||||
from os import getpid
|
||||
from typing import Any, Callable, Tuple
|
||||
from typing import Any, Callable
|
||||
|
||||
from torch.multiprocessing import Queue, Value
|
||||
|
||||
from .command import JobCommand, ProgressCommand
|
||||
from ..params import DeviceParams
|
||||
from .command import JobCommand, ProgressCommand
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -73,17 +73,47 @@ class WorkerContext:
|
|||
raise RuntimeError("job has been cancelled")
|
||||
else:
|
||||
logger.debug("setting progress for job %s to %s", self.job, progress)
|
||||
self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False)
|
||||
self.progress.put(
|
||||
ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
False,
|
||||
progress,
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
),
|
||||
block=False,
|
||||
)
|
||||
|
||||
def set_finished(self) -> None:
|
||||
logger.debug("setting finished for job %s", self.job)
|
||||
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False)
|
||||
self.progress.put(
|
||||
ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
),
|
||||
block=False,
|
||||
)
|
||||
|
||||
def set_failed(self) -> None:
|
||||
logger.warning("setting failure for job %s", self.job)
|
||||
try:
|
||||
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), True), block=False)
|
||||
except:
|
||||
self.progress.put(
|
||||
ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
True,
|
||||
),
|
||||
block=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("error setting failure on job %s", self.job)
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ from collections import Counter
|
|||
from logging import getLogger
|
||||
from queue import Empty
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch.multiprocessing import Process, Queue, Value
|
||||
|
||||
|
@ -29,7 +29,7 @@ class DevicePoolExecutor:
|
|||
threads: Dict[str, Thread]
|
||||
workers: Dict[str, Process]
|
||||
|
||||
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
|
||||
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
|
||||
cancelled_jobs: List[str]
|
||||
finished_jobs: List[ProgressCommand]
|
||||
total_jobs: Dict[str, int] # Device -> job count
|
||||
|
@ -149,11 +149,15 @@ class DevicePoolExecutor:
|
|||
del self.active_jobs[progress.job]
|
||||
self.join_leaking()
|
||||
else:
|
||||
logger.debug("progress update for job: %s to %s", progress.job, progress.progress)
|
||||
logger.debug(
|
||||
"progress update for job: %s to %s", progress.job, progress.progress
|
||||
)
|
||||
self.active_jobs[progress.job] = progress
|
||||
if progress.job in self.cancelled_jobs:
|
||||
logger.debug(
|
||||
"setting flag for cancelled job: %s on %s", progress.job, progress.device
|
||||
"setting flag for cancelled job: %s on %s",
|
||||
progress.job,
|
||||
progress.device,
|
||||
)
|
||||
self.context[progress.device].set_cancel()
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ from logging import getLogger
|
|||
from os import getpid
|
||||
from queue import Empty
|
||||
from sys import exit
|
||||
from traceback import format_exception
|
||||
|
||||
from setproctitle import setproctitle
|
||||
|
||||
|
@ -54,10 +53,8 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
|||
logger.info("worker got keyboard interrupt")
|
||||
context.set_failed()
|
||||
exit(EXIT_INTERRUPT)
|
||||
except ValueError as e:
|
||||
logger.exception(
|
||||
"value error in worker, exiting: %s"
|
||||
)
|
||||
except ValueError:
|
||||
logger.exception("value error in worker, exiting: %s")
|
||||
context.set_failed()
|
||||
exit(EXIT_ERROR)
|
||||
except Exception as e:
|
||||
|
|
Loading…
Reference in New Issue