
feat(api): add error flag to image ready response

Sean Sube 2023-03-18 15:12:09 -05:00
parent 0ab52f0c24
commit 7cf5554bef
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 141 additions and 77 deletions

View File

@@ -50,9 +50,11 @@ from .utils import wrap_route
 logger = getLogger(__name__)


-def ready_reply(ready: bool, progress: int = 0):
+def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False):
     return jsonify(
         {
+            "cancel": cancel,
+            "error": error,
             "progress": progress,
             "ready": ready,
         }
@@ -437,7 +439,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
     output_file = sanitize_name(output_file)
     cancel = pool.cancel(output_file)

-    return ready_reply(cancel)
+    return ready_reply(cancel == False, cancel=cancel)


def ready(context: ServerContext, pool: DevicePoolExecutor):
@@ -446,14 +448,14 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
         return error_reply("output name is required")

     output_file = sanitize_name(output_file)
-    done, progress = pool.done(output_file)
+    progress = pool.done(output_file)

-    if done is None:
+    if progress is None:
         output = base_join(context.output_path, output_file)
         if path.exists(output):
             return ready_reply(True)

-    return ready_reply(done or False, progress=progress)
+    return ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel)


def status(context: ServerContext, pool: DevicePoolExecutor):
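
Note: with the ready_reply change above, polling the ready endpoint now yields both new flags alongside the existing fields. A minimal sketch of the payload, written as the dict handed to jsonify (values are illustrative, not from this commit):

    # Illustrative reply for a job that finished normally; a failed job would
    # report "error": True instead.
    {
        "cancel": False,
        "error": False,
        "progress": 100,
        "ready": True,
    }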

View File

@@ -0,0 +1,43 @@
+from typing import Callable, Any
+
+
+class ProgressCommand():
+    device: str
+    job: str
+    finished: bool
+    progress: int
+    cancel: bool
+    error: bool
+
+    def __init__(
+        self,
+        job: str,
+        device: str,
+        finished: bool,
+        progress: int,
+        cancel: bool = False,
+        error: bool = False,
+    ):
+        self.job = job
+        self.device = device
+        self.finished = finished
+        self.progress = progress
+        self.cancel = cancel
+        self.error = error
+
+
+class JobCommand():
+    name: str
+    fn: Callable[..., None]
+    args: Any
+    kwargs: dict[str, Any]
+
+    def __init__(
+        self,
+        name: str,
+        fn: Callable[..., None],
+        args: Any,
+        kwargs: dict[str, Any],
+    ):
+        self.name = name
+        self.fn = fn
+        self.args = args
+        self.kwargs = kwargs
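
Note: a rough sketch of how these two command classes fit together in the rest of this commit, assuming a hypothetical run_txt2img callable: JobCommand wraps a queued call for the worker to execute, and ProgressCommand carries a single status update back to the pool.

    # Illustrative only; argument order follows the constructors above.
    job = JobCommand("txt2img-sample", run_txt2img, args=(), kwargs={"size": 512})
    update = ProgressCommand(job.name, "cuda:0", finished=False, progress=10)

    # A finished update with error set marks a failed job in the status/ready replies.
    failed = update.finished and update.error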

View File

@@ -4,6 +4,7 @@ from typing import Any, Callable, Tuple

 from torch.multiprocessing import Queue, Value

+from .command import JobCommand, ProgressCommand
 from ..params import DeviceParams

 logger = getLogger(__name__)
@@ -15,9 +16,9 @@ ProgressCallback = Callable[[int, int, Any], None]
 class WorkerContext:
     cancel: "Value[bool]"
     job: str
-    pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
+    pending: "Queue[JobCommand]"
     current: "Value[int]"
-    progress: "Queue[Tuple[str, str, int]]"
+    progress: "Queue[ProgressCommand]"

     def __init__(
         self,
@@ -25,16 +26,14 @@ class WorkerContext:
         device: DeviceParams,
         cancel: "Value[bool]",
         logs: "Queue[str]",
-        pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
-        progress: "Queue[Tuple[str, str, int]]",
-        finished: "Queue[Tuple[str, str]]",
+        pending: "Queue[JobCommand]",
+        progress: "Queue[ProgressCommand]",
         current: "Value[int]",
     ):
         self.job = job
         self.device = device
         self.cancel = cancel
         self.progress = progress
-        self.finished = finished
         self.logs = logs
         self.pending = pending
         self.current = current
@@ -61,11 +60,7 @@ class WorkerContext:
     def get_progress_callback(self) -> ProgressCallback:
         def on_progress(step: int, timestep: int, latents: Any):
             on_progress.step = step
-            if self.is_cancelled():
-                raise RuntimeError("job has been cancelled")
-            else:
-                logger.debug("setting progress for job %s to %s", self.job, step)
-                self.set_progress(step)
+            self.set_progress(step)

         return on_progress

@@ -74,14 +69,22 @@ class WorkerContext:
         self.cancel.value = cancel

     def set_progress(self, progress: int) -> None:
-        self.progress.put((self.job, self.device.device, progress), block=False)
+        if self.is_cancelled():
+            raise RuntimeError("job has been cancelled")
+        else:
+            logger.debug("setting progress for job %s to %s", self.job, progress)
+            self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False)

     def set_finished(self) -> None:
-        self.finished.put((self.job, self.device.device), block=False)
+        logger.debug("setting finished for job %s", self.job)
+        self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False)

-    def clear_flags(self) -> None:
-        self.set_cancel(False)
-        self.set_progress(0)
+    def set_failed(self) -> None:
+        logger.warning("setting failure for job %s", self.job)
+        try:
+            self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), True), block=False)
+        except:
+            logger.exception("error setting failure on job %s", self.job)


 class JobStatus:
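
Note: after this change, cancellation surfaces through set_progress itself rather than inside the progress callback: once the shared cancel flag is set, the next set_progress call raises RuntimeError, which unwinds the running pipeline. A rough sketch of that path, assuming an existing WorkerContext instance named worker (not part of the commit):

    callback = worker.get_progress_callback()
    try:
        callback(5, 0, None)  # forwards to worker.set_progress(5)
    except RuntimeError:
        # raised once the pool has set the cancel flag for this job;
        # the worker's exception handling then records the failure
        worker.set_failed()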

View File

@@ -8,6 +8,7 @@ from torch.multiprocessing import Process, Queue, Value

 from ..params import DeviceParams
 from ..server import ServerContext
+from .command import JobCommand, ProgressCommand
 from .context import WorkerContext
 from .worker import worker_main

@@ -24,18 +25,17 @@ class DevicePoolExecutor:
     leaking: List[Tuple[str, Process]]
     context: Dict[str, WorkerContext]  # Device -> Context
     current: Dict[str, "Value[int]"]
-    pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
+    pending: Dict[str, "Queue[JobCommand]"]
     threads: Dict[str, Thread]
     workers: Dict[str, Process]

-    active_jobs: Dict[str, Tuple[str, int]]  # should be Dict[Device, JobStatus]
+    active_jobs: Dict[str, ProgressCommand]  # Device -> job progress
     cancelled_jobs: List[str]
-    finished_jobs: List[Tuple[str, int, bool]]  # should be List[JobStatus]
+    finished_jobs: List[ProgressCommand]
     total_jobs: Dict[str, int]  # Device -> job count

-    logs: "Queue"
-    progress: "Queue[Tuple[str, str, int]]"
-    finished: "Queue[Tuple[str, str]]"
+    logs: "Queue[str]"
+    progress: "Queue[ProgressCommand]"

     def __init__(
         self,
@@ -142,18 +142,27 @@ class DevicePoolExecutor:
         logger_thread.start()

     def create_progress_worker(self) -> None:
-        def progress_worker(progress: Queue):
+        def update_job(progress: ProgressCommand):
+            if progress.finished:
+                logger.info("job has finished: %s", progress.job)
+                self.finished_jobs.append(progress)
+                del self.active_jobs[progress.job]
+                self.join_leaking()
+            else:
+                logger.debug("progress update for job: %s to %s", progress.job, progress.progress)
+                self.active_jobs[progress.job] = progress
+                if progress.job in self.cancelled_jobs:
+                    logger.debug(
+                        "setting flag for cancelled job: %s on %s", progress.job, progress.device
+                    )
+                    self.context[progress.device].set_cancel()
+
+        def progress_worker(queue: "Queue[ProgressCommand]"):
             logger.trace("checking in from progress worker thread")

             while True:
                 try:
-                    job, device, value = progress.get(timeout=(self.join_timeout / 2))
-                    logger.debug("progress update for job: %s to %s", job, value)
-                    self.active_jobs[job] = (device, value)
-                    if job in self.cancelled_jobs:
-                        logger.debug(
-                            "setting flag for cancelled job: %s on %s", job, device
-                        )
-                        self.context[device].set_cancel()
+                    progress = queue.get(timeout=(self.join_timeout / 2))
+                    update_job(progress)
                 except Empty:
                     pass
                 except ValueError:
@@ -178,12 +187,7 @@ class DevicePoolExecutor:
             while True:
                 try:
                     job, device = finished.get(timeout=(self.join_timeout / 2))
-                    logger.info("job has been finished: %s", job)
-                    context = self.context[device]
-                    _device, progress = self.active_jobs[job]
-                    self.finished_jobs.append((job, progress, context.cancel.value))
-                    del self.active_jobs[job]
-                    self.join_leaking()
                 except Empty:
                     pass
                 except ValueError:
@@ -232,37 +236,36 @@ class DevicePoolExecutor:
         should be cancelled on the next progress callback.
         """
-        self.cancelled_jobs.append(key)
+        for job in self.finished_jobs:
+            if job.job == key:
+                logger.debug("cannot cancel finished job: %s", key)
+                return False

         if key not in self.active_jobs:
-            logger.debug("cancelled job has not been started yet: %s", key)
-            return True
+            logger.debug("cancelled job is not active: %s", key)
+        else:
+            job = self.active_jobs[key]
+            logger.info("cancelling job %s, active on device %s", key, job.device)

-        device, _progress = self.active_jobs[key]
-        logger.info("cancelling job %s, active on device %s", key, device)
-        context = self.context[device]
-        context.set_cancel()
-
+        self.cancelled_jobs.append(key)
         return True

-    def done(self, key: str) -> Tuple[Optional[bool], int]:
+    def done(self, key: str) -> Optional[ProgressCommand]:
         """
         Check if a job has been finished and report the last progress update.

         If the job is still active or pending, the first item will be False.
         If the job is not finished or active, the first item will be None.
         """
-        for k, p, c in self.finished_jobs:
-            if k == key:
-                return (True, p)
+        for job in self.finished_jobs:
+            if job.job == key:
+                return job

         if key not in self.active_jobs:
             logger.debug("checking status for unknown job: %s", key)
-            return (None, 0)
+            return None

-        _device, progress = self.active_jobs[key]
-        return (False, progress)
+        return self.active_jobs[key]

     def join(self):
         logger.info("stopping worker pool")
@@ -387,22 +390,29 @@ class DevicePoolExecutor:
         logger.debug("job count for device %s: %s", device, self.total_jobs[device])

         self.recycle()
-        self.pending[device].put((key, fn, args, kwargs), block=False)
+        self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)

-    def status(self) -> List[Tuple[str, int, bool, bool]]:
+    def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
         history = [
-            (name, progress, False, name in self.cancelled_jobs)
-            for name, (_device, progress) in self.active_jobs.items()
+            (
+                name,
+                job.progress,
+                job.finished,
+                job.cancel,
+                job.error,
+            )
+            for name, job in self.active_jobs.items()
         ]
         history.extend(
             [
                 (
-                    name,
-                    progress,
-                    True,
-                    cancel,
+                    job.job,
+                    job.progress,
+                    job.finished,
+                    job.cancel,
+                    job.error,
                 )
-                for name, progress, cancel in self.finished_jobs
+                for job in self.finished_jobs
             ]
         )
         return history
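
Note: done() now returns the last ProgressCommand for a job (or None), which is exactly what the updated ready() handler in the first file unpacks into the response flags. A short sketch, assuming an existing DevicePoolExecutor named pool (illustrative only):

    progress = pool.done("txt2img-sample")  # Optional[ProgressCommand]
    if progress is None:
        # unknown job; the API falls back to checking for the output file on disk
        pass
    else:
        print(progress.finished, progress.progress, progress.error, progress.cancel)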

View File

@@ -39,32 +39,36 @@ def worker_main(context: WorkerContext, server: ServerContext):
                 )
                 exit(EXIT_REPLACED)

-            name, fn, args, kwargs = context.pending.get(timeout=1.0)
-            logger.info("worker for %s got job: %s", context.device.device, name)
+            job = context.pending.get(timeout=1.0)
+            logger.info("worker for %s got job: %s", context.device.device, job.name)

-            context.job = name  # TODO: hax
-            context.clear_flags()
-            logger.info("starting job: %s", name)
-            fn(context, *args, **kwargs)
-            logger.info("job succeeded: %s", name)
+            context.job = job.name  # TODO: hax
+            logger.info("starting job: %s", job.name)
+            context.set_progress(0)
+            job.fn(context, *job.args, **job.kwargs)
+            logger.info("job succeeded: %s", job.name)
             context.set_finished()
         except Empty:
             pass
         except KeyboardInterrupt:
             logger.info("worker got keyboard interrupt")
+            context.set_failed()
             exit(EXIT_INTERRUPT)
         except ValueError as e:
-            logger.info(
-                "value error in worker, exiting: %s",
-                format_exception(type(e), e, e.__traceback__),
-            )
+            logger.exception(
+                "value error in worker, exiting: %s"
+            )
+            context.set_failed()
             exit(EXIT_ERROR)
         except Exception as e:
             e_str = str(e)
             if "Failed to allocate memory" in e_str or "out of memory" in e_str:
                 logger.error("detected out-of-memory error, exiting: %s", e)
+                context.set_failed()
                 exit(EXIT_MEMORY)
             else:
                 logger.exception(
                     "error while running job",
                 )
+                context.set_failed()
+                # carry on

View File

@@ -175,6 +175,8 @@ export interface ImageResponse {
  * Status response from the ready endpoint.
  */
 export interface ReadyResponse {
+  cancel: boolean;
+  error: boolean;
   progress: number;
   ready: boolean;
 }