1
0
Fork 0

feat(api): add error flag to image ready response

This commit is contained in:
Sean Sube 2023-03-18 15:12:09 -05:00
parent 0ab52f0c24
commit 7cf5554bef
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
6 changed files with 141 additions and 77 deletions

View File

@ -50,9 +50,11 @@ from .utils import wrap_route
logger = getLogger(__name__)
def ready_reply(ready: bool, progress: int = 0):
def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False):
return jsonify(
{
"cancel": cancel,
"error": error,
"progress": progress,
"ready": ready,
}
@ -437,7 +439,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
output_file = sanitize_name(output_file)
cancel = pool.cancel(output_file)
return ready_reply(cancel)
return ready_reply(cancel == False, cancel=cancel)
def ready(context: ServerContext, pool: DevicePoolExecutor):
@ -446,14 +448,14 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
return error_reply("output name is required")
output_file = sanitize_name(output_file)
done, progress = pool.done(output_file)
progress = pool.done(output_file)
if done is None:
if progress is None:
output = base_join(context.output_path, output_file)
if path.exists(output):
return ready_reply(True)
return ready_reply(done or False, progress=progress)
return ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel)
def status(context: ServerContext, pool: DevicePoolExecutor):

View File

@ -0,0 +1,43 @@
from typing import Any, Callable, Dict
class ProgressCommand():
    """Progress report for a single job.

    Carries how far the job has run on which device, plus the terminal
    flags: whether it finished, was cancelled, or failed with an error.

    NOTE: attributes are declared in the same order as the constructor
    parameters (job before device) — the original declaration order was
    swapped relative to __init__, which was confusing.
    """

    job: str        # job name/key, used to look up status
    device: str     # device the job is (or was) running on
    finished: bool  # True once the job has completed (success or failure)
    progress: int   # last reported progress step
    cancel: bool    # True if the job was cancelled
    error: bool     # True if the job failed with an error

    def __init__(
        self,
        job: str,
        device: str,
        finished: bool,
        progress: int,
        cancel: bool = False,
        error: bool = False,
    ):
        self.job = job
        self.device = device
        self.finished = finished
        self.progress = progress
        self.cancel = cancel
        self.error = error
class JobCommand():
    """A unit of work queued for a device worker.

    Bundles a named callable with the positional and keyword arguments it
    should be invoked with, so the whole job can travel through a
    multiprocessing Queue as one object.

    Uses typing.Dict rather than the bare `dict[...]` generic: class-level
    annotations are evaluated at runtime, and subscripting builtin `dict`
    requires Python 3.9+, while the sibling modules use typing.Dict/Tuple.
    """

    name: str                # job name/key, echoed in progress reports
    fn: Callable[..., None]  # the function the worker will execute
    args: Any                # positional arguments passed through to fn
    kwargs: Dict[str, Any]   # keyword arguments passed through to fn

    def __init__(
        self,
        name: str,
        fn: Callable[..., None],
        args: Any,
        kwargs: Dict[str, Any],
    ):
        self.name = name
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

View File

@ -4,6 +4,7 @@ from typing import Any, Callable, Tuple
from torch.multiprocessing import Queue, Value
from .command import JobCommand, ProgressCommand
from ..params import DeviceParams
logger = getLogger(__name__)
@ -15,9 +16,9 @@ ProgressCallback = Callable[[int, int, Any], None]
class WorkerContext:
cancel: "Value[bool]"
job: str
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
pending: "Queue[JobCommand]"
current: "Value[int]"
progress: "Queue[Tuple[str, str, int]]"
progress: "Queue[ProgressCommand]"
def __init__(
self,
@ -25,16 +26,14 @@ class WorkerContext:
device: DeviceParams,
cancel: "Value[bool]",
logs: "Queue[str]",
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
progress: "Queue[Tuple[str, str, int]]",
finished: "Queue[Tuple[str, str]]",
pending: "Queue[JobCommand]",
progress: "Queue[ProgressCommand]",
current: "Value[int]",
):
self.job = job
self.device = device
self.cancel = cancel
self.progress = progress
self.finished = finished
self.logs = logs
self.pending = pending
self.current = current
@ -61,10 +60,6 @@ class WorkerContext:
def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step
if self.is_cancelled():
raise RuntimeError("job has been cancelled")
else:
logger.debug("setting progress for job %s to %s", self.job, step)
self.set_progress(step)
return on_progress
@ -74,14 +69,22 @@ class WorkerContext:
self.cancel.value = cancel
def set_progress(self, progress: int) -> None:
self.progress.put((self.job, self.device.device, progress), block=False)
if self.is_cancelled():
raise RuntimeError("job has been cancelled")
else:
logger.debug("setting progress for job %s to %s", self.job, progress)
self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False)
def set_finished(self) -> None:
self.finished.put((self.job, self.device.device), block=False)
logger.debug("setting finished for job %s", self.job)
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False)
def clear_flags(self) -> None:
self.set_cancel(False)
self.set_progress(0)
def set_failed(self) -> None:
logger.warning("setting failure for job %s", self.job)
try:
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), True), block=False)
except:
logger.exception("error setting failure on job %s", self.job)
class JobStatus:

View File

@ -8,6 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
from ..params import DeviceParams
from ..server import ServerContext
from .command import JobCommand, ProgressCommand
from .context import WorkerContext
from .worker import worker_main
@ -24,18 +25,17 @@ class DevicePoolExecutor:
leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"]
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
pending: Dict[str, "Queue[JobCommand]"]
threads: Dict[str, Thread]
workers: Dict[str, Process]
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
cancelled_jobs: List[str]
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
finished_jobs: List[ProgressCommand]
total_jobs: Dict[str, int] # Device -> job count
logs: "Queue"
progress: "Queue[Tuple[str, str, int]]"
finished: "Queue[Tuple[str, str]]"
logs: "Queue[str]"
progress: "Queue[ProgressCommand]"
def __init__(
self,
@ -142,18 +142,27 @@ class DevicePoolExecutor:
logger_thread.start()
def create_progress_worker(self) -> None:
def progress_worker(progress: Queue):
def update_job(progress: ProgressCommand):
if progress.finished:
logger.info("job has finished: %s", progress.job)
self.finished_jobs.append(progress)
del self.active_jobs[progress.job]
self.join_leaking()
else:
logger.debug("progress update for job: %s to %s", progress.job, progress.progress)
self.active_jobs[progress.job] = progress
if progress.job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s", progress.job, progress.device
)
self.context[progress.device].set_cancel()
def progress_worker(queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
job, device, value = progress.get(timeout=(self.join_timeout / 2))
logger.debug("progress update for job: %s to %s", job, value)
self.active_jobs[job] = (device, value)
if job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s", job, device
)
self.context[device].set_cancel()
progress = queue.get(timeout=(self.join_timeout / 2))
update_job(progress)
except Empty:
pass
except ValueError:
@ -178,12 +187,7 @@ class DevicePoolExecutor:
while True:
try:
job, device = finished.get(timeout=(self.join_timeout / 2))
logger.info("job has been finished: %s", job)
context = self.context[device]
_device, progress = self.active_jobs[job]
self.finished_jobs.append((job, progress, context.cancel.value))
del self.active_jobs[job]
self.join_leaking()
except Empty:
pass
except ValueError:
@ -232,37 +236,36 @@ class DevicePoolExecutor:
should be cancelled on the next progress callback.
"""
self.cancelled_jobs.append(key)
for job in self.finished_jobs:
if job.job == key:
logger.debug("cannot cancel finished job: %s", key)
return False
if key not in self.active_jobs:
logger.debug("cancelled job has not been started yet: %s", key)
logger.debug("cancelled job is not active: %s", key)
else:
job = self.active_jobs[key]
logger.info("cancelling job %s, active on device %s", key, job.device)
self.cancelled_jobs.append(key)
return True
device, _progress = self.active_jobs[key]
logger.info("cancelling job %s, active on device %s", key, device)
context = self.context[device]
context.set_cancel()
return True
def done(self, key: str) -> Tuple[Optional[bool], int]:
def done(self, key: str) -> Optional[ProgressCommand]:
"""
Check if a job has been finished and report the last progress update.
If the job is still active or pending, the first item will be False.
If the job is not finished or active, the first item will be None.
"""
for k, p, c in self.finished_jobs:
if k == key:
return (True, p)
for job in self.finished_jobs:
if job.job == key:
return job
if key not in self.active_jobs:
logger.debug("checking status for unknown job: %s", key)
return (None, 0)
return None
_device, progress = self.active_jobs[key]
return (False, progress)
return self.active_jobs[key]
def join(self):
logger.info("stopping worker pool")
@ -387,22 +390,29 @@ class DevicePoolExecutor:
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
self.recycle()
self.pending[device].put((key, fn, args, kwargs), block=False)
self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)
def status(self) -> List[Tuple[str, int, bool, bool]]:
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
history = [
(name, progress, False, name in self.cancelled_jobs)
for name, (_device, progress) in self.active_jobs.items()
(
name,
job.progress,
job.finished,
job.cancel,
job.error,
)
for name, job in self.active_jobs.items()
]
history.extend(
[
(
name,
progress,
True,
cancel,
job.job,
job.progress,
job.finished,
job.cancel,
job.error,
)
for name, progress, cancel in self.finished_jobs
for job in self.finished_jobs
]
)
return history

View File

@ -39,32 +39,36 @@ def worker_main(context: WorkerContext, server: ServerContext):
)
exit(EXIT_REPLACED)
name, fn, args, kwargs = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, name)
job = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, job.name)
context.job = name # TODO: hax
context.clear_flags()
logger.info("starting job: %s", name)
fn(context, *args, **kwargs)
logger.info("job succeeded: %s", name)
context.job = job.name # TODO: hax
logger.info("starting job: %s", job.name)
context.set_progress(0)
job.fn(context, *job.args, **job.kwargs)
logger.info("job succeeded: %s", job.name)
context.set_finished()
except Empty:
pass
except KeyboardInterrupt:
logger.info("worker got keyboard interrupt")
context.set_failed()
exit(EXIT_INTERRUPT)
except ValueError as e:
logger.info(
"value error in worker, exiting: %s",
format_exception(type(e), e, e.__traceback__),
logger.exception(
"value error in worker, exiting: %s"
)
context.set_failed()
exit(EXIT_ERROR)
except Exception as e:
e_str = str(e)
if "Failed to allocate memory" in e_str or "out of memory" in e_str:
logger.error("detected out-of-memory error, exiting: %s", e)
context.set_failed()
exit(EXIT_MEMORY)
else:
logger.exception(
"error while running job",
)
context.set_failed()
# carry on

View File

@ -175,6 +175,8 @@ export interface ImageResponse {
* Status response from the ready endpoint.
*/
export interface ReadyResponse {
  /** Whether the job was cancelled (presumably before completing — confirm against server). */
  cancel: boolean;
  /** Whether the job failed with an error. */
  error: boolean;
  /** Last reported progress value for the job (appears to be a step count — confirm). */
  progress: number;
  /** Whether the job has finished and its output is ready. */
  ready: boolean;
}