feat(api): add error flag to image ready response
This commit is contained in:
parent
0ab52f0c24
commit
7cf5554bef
|
@ -50,9 +50,11 @@ from .utils import wrap_route
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def ready_reply(ready: bool, progress: int = 0):
|
||||
def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False):
|
||||
return jsonify(
|
||||
{
|
||||
"cancel": cancel,
|
||||
"error": error,
|
||||
"progress": progress,
|
||||
"ready": ready,
|
||||
}
|
||||
|
@ -437,7 +439,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
|||
output_file = sanitize_name(output_file)
|
||||
cancel = pool.cancel(output_file)
|
||||
|
||||
return ready_reply(cancel)
|
||||
return ready_reply(cancel == False, cancel=cancel)
|
||||
|
||||
|
||||
def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||
|
@ -446,14 +448,14 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
|||
return error_reply("output name is required")
|
||||
|
||||
output_file = sanitize_name(output_file)
|
||||
done, progress = pool.done(output_file)
|
||||
progress = pool.done(output_file)
|
||||
|
||||
if done is None:
|
||||
if progress is None:
|
||||
output = base_join(context.output_path, output_file)
|
||||
if path.exists(output):
|
||||
return ready_reply(True)
|
||||
|
||||
return ready_reply(done or False, progress=progress)
|
||||
return ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel)
|
||||
|
||||
|
||||
def status(context: ServerContext, pool: DevicePoolExecutor):
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
from typing import Callable, Any
|
||||
|
||||
class ProgressCommand():
|
||||
device: str
|
||||
job: str
|
||||
finished: bool
|
||||
progress: int
|
||||
cancel: bool
|
||||
error: bool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
job: str,
|
||||
device: str,
|
||||
finished: bool,
|
||||
progress: int,
|
||||
cancel: bool = False,
|
||||
error: bool = False,
|
||||
):
|
||||
self.job = job
|
||||
self.device = device
|
||||
self.finished = finished
|
||||
self.progress = progress
|
||||
self.cancel = cancel
|
||||
self.error = error
|
||||
|
||||
class JobCommand():
|
||||
name: str
|
||||
fn: Callable[..., None]
|
||||
args: Any
|
||||
kwargs: dict[str, Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
fn: Callable[..., None],
|
||||
args: Any,
|
||||
kwargs: dict[str, Any],
|
||||
):
|
||||
self.name = name
|
||||
self.fn = fn
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
|
@ -4,6 +4,7 @@ from typing import Any, Callable, Tuple
|
|||
|
||||
from torch.multiprocessing import Queue, Value
|
||||
|
||||
from .command import JobCommand, ProgressCommand
|
||||
from ..params import DeviceParams
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -15,9 +16,9 @@ ProgressCallback = Callable[[int, int, Any], None]
|
|||
class WorkerContext:
|
||||
cancel: "Value[bool]"
|
||||
job: str
|
||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
|
||||
pending: "Queue[JobCommand]"
|
||||
current: "Value[int]"
|
||||
progress: "Queue[Tuple[str, str, int]]"
|
||||
progress: "Queue[ProgressCommand]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -25,16 +26,14 @@ class WorkerContext:
|
|||
device: DeviceParams,
|
||||
cancel: "Value[bool]",
|
||||
logs: "Queue[str]",
|
||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
|
||||
progress: "Queue[Tuple[str, str, int]]",
|
||||
finished: "Queue[Tuple[str, str]]",
|
||||
pending: "Queue[JobCommand]",
|
||||
progress: "Queue[ProgressCommand]",
|
||||
current: "Value[int]",
|
||||
):
|
||||
self.job = job
|
||||
self.device = device
|
||||
self.cancel = cancel
|
||||
self.progress = progress
|
||||
self.finished = finished
|
||||
self.logs = logs
|
||||
self.pending = pending
|
||||
self.current = current
|
||||
|
@ -61,10 +60,6 @@ class WorkerContext:
|
|||
def get_progress_callback(self) -> ProgressCallback:
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
on_progress.step = step
|
||||
if self.is_cancelled():
|
||||
raise RuntimeError("job has been cancelled")
|
||||
else:
|
||||
logger.debug("setting progress for job %s to %s", self.job, step)
|
||||
self.set_progress(step)
|
||||
|
||||
return on_progress
|
||||
|
@ -74,14 +69,22 @@ class WorkerContext:
|
|||
self.cancel.value = cancel
|
||||
|
||||
def set_progress(self, progress: int) -> None:
|
||||
self.progress.put((self.job, self.device.device, progress), block=False)
|
||||
if self.is_cancelled():
|
||||
raise RuntimeError("job has been cancelled")
|
||||
else:
|
||||
logger.debug("setting progress for job %s to %s", self.job, progress)
|
||||
self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False)
|
||||
|
||||
def set_finished(self) -> None:
|
||||
self.finished.put((self.job, self.device.device), block=False)
|
||||
logger.debug("setting finished for job %s", self.job)
|
||||
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False)
|
||||
|
||||
def clear_flags(self) -> None:
|
||||
self.set_cancel(False)
|
||||
self.set_progress(0)
|
||||
def set_failed(self) -> None:
|
||||
logger.warning("setting failure for job %s", self.job)
|
||||
try:
|
||||
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), True), block=False)
|
||||
except:
|
||||
logger.exception("error setting failure on job %s", self.job)
|
||||
|
||||
|
||||
class JobStatus:
|
||||
|
|
|
@ -8,6 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
|
|||
|
||||
from ..params import DeviceParams
|
||||
from ..server import ServerContext
|
||||
from .command import JobCommand, ProgressCommand
|
||||
from .context import WorkerContext
|
||||
from .worker import worker_main
|
||||
|
||||
|
@ -24,18 +25,17 @@ class DevicePoolExecutor:
|
|||
leaking: List[Tuple[str, Process]]
|
||||
context: Dict[str, WorkerContext] # Device -> Context
|
||||
current: Dict[str, "Value[int]"]
|
||||
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
|
||||
pending: Dict[str, "Queue[JobCommand]"]
|
||||
threads: Dict[str, Thread]
|
||||
workers: Dict[str, Process]
|
||||
|
||||
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
|
||||
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
|
||||
cancelled_jobs: List[str]
|
||||
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
|
||||
finished_jobs: List[ProgressCommand]
|
||||
total_jobs: Dict[str, int] # Device -> job count
|
||||
|
||||
logs: "Queue"
|
||||
progress: "Queue[Tuple[str, str, int]]"
|
||||
finished: "Queue[Tuple[str, str]]"
|
||||
logs: "Queue[str]"
|
||||
progress: "Queue[ProgressCommand]"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -142,18 +142,27 @@ class DevicePoolExecutor:
|
|||
logger_thread.start()
|
||||
|
||||
def create_progress_worker(self) -> None:
|
||||
def progress_worker(progress: Queue):
|
||||
def update_job(progress: ProgressCommand):
|
||||
if progress.finished:
|
||||
logger.info("job has finished: %s", progress.job)
|
||||
self.finished_jobs.append(progress)
|
||||
del self.active_jobs[progress.job]
|
||||
self.join_leaking()
|
||||
else:
|
||||
logger.debug("progress update for job: %s to %s", progress.job, progress.progress)
|
||||
self.active_jobs[progress.job] = progress
|
||||
if progress.job in self.cancelled_jobs:
|
||||
logger.debug(
|
||||
"setting flag for cancelled job: %s on %s", progress.job, progress.device
|
||||
)
|
||||
self.context[progress.device].set_cancel()
|
||||
|
||||
def progress_worker(queue: "Queue[ProgressCommand]"):
|
||||
logger.trace("checking in from progress worker thread")
|
||||
while True:
|
||||
try:
|
||||
job, device, value = progress.get(timeout=(self.join_timeout / 2))
|
||||
logger.debug("progress update for job: %s to %s", job, value)
|
||||
self.active_jobs[job] = (device, value)
|
||||
if job in self.cancelled_jobs:
|
||||
logger.debug(
|
||||
"setting flag for cancelled job: %s on %s", job, device
|
||||
)
|
||||
self.context[device].set_cancel()
|
||||
progress = queue.get(timeout=(self.join_timeout / 2))
|
||||
update_job(progress)
|
||||
except Empty:
|
||||
pass
|
||||
except ValueError:
|
||||
|
@ -178,12 +187,7 @@ class DevicePoolExecutor:
|
|||
while True:
|
||||
try:
|
||||
job, device = finished.get(timeout=(self.join_timeout / 2))
|
||||
logger.info("job has been finished: %s", job)
|
||||
context = self.context[device]
|
||||
_device, progress = self.active_jobs[job]
|
||||
self.finished_jobs.append((job, progress, context.cancel.value))
|
||||
del self.active_jobs[job]
|
||||
self.join_leaking()
|
||||
|
||||
except Empty:
|
||||
pass
|
||||
except ValueError:
|
||||
|
@ -232,37 +236,36 @@ class DevicePoolExecutor:
|
|||
should be cancelled on the next progress callback.
|
||||
"""
|
||||
|
||||
self.cancelled_jobs.append(key)
|
||||
for job in self.finished_jobs:
|
||||
if job.job == key:
|
||||
logger.debug("cannot cancel finished job: %s", key)
|
||||
return False
|
||||
|
||||
if key not in self.active_jobs:
|
||||
logger.debug("cancelled job has not been started yet: %s", key)
|
||||
logger.debug("cancelled job is not active: %s", key)
|
||||
else:
|
||||
job = self.active_jobs[key]
|
||||
logger.info("cancelling job %s, active on device %s", key, job.device)
|
||||
|
||||
self.cancelled_jobs.append(key)
|
||||
return True
|
||||
|
||||
device, _progress = self.active_jobs[key]
|
||||
logger.info("cancelling job %s, active on device %s", key, device)
|
||||
|
||||
context = self.context[device]
|
||||
context.set_cancel()
|
||||
|
||||
return True
|
||||
|
||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||
def done(self, key: str) -> Optional[ProgressCommand]:
|
||||
"""
|
||||
Check if a job has been finished and report the last progress update.
|
||||
|
||||
If the job is still active or pending, the first item will be False.
|
||||
If the job is not finished or active, the result will be None.
|
||||
"""
|
||||
for k, p, c in self.finished_jobs:
|
||||
if k == key:
|
||||
return (True, p)
|
||||
for job in self.finished_jobs:
|
||||
if job.job == key:
|
||||
return job
|
||||
|
||||
if key not in self.active_jobs:
|
||||
logger.debug("checking status for unknown job: %s", key)
|
||||
return (None, 0)
|
||||
return None
|
||||
|
||||
_device, progress = self.active_jobs[key]
|
||||
return (False, progress)
|
||||
return self.active_jobs[key]
|
||||
|
||||
def join(self):
|
||||
logger.info("stopping worker pool")
|
||||
|
@ -387,22 +390,29 @@ class DevicePoolExecutor:
|
|||
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
||||
self.recycle()
|
||||
|
||||
self.pending[device].put((key, fn, args, kwargs), block=False)
|
||||
self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)
|
||||
|
||||
def status(self) -> List[Tuple[str, int, bool, bool]]:
|
||||
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
|
||||
history = [
|
||||
(name, progress, False, name in self.cancelled_jobs)
|
||||
for name, (_device, progress) in self.active_jobs.items()
|
||||
(
|
||||
name,
|
||||
job.progress,
|
||||
job.finished,
|
||||
job.cancel,
|
||||
job.error,
|
||||
)
|
||||
for name, job in self.active_jobs.items()
|
||||
]
|
||||
history.extend(
|
||||
[
|
||||
(
|
||||
name,
|
||||
progress,
|
||||
True,
|
||||
cancel,
|
||||
job.job,
|
||||
job.progress,
|
||||
job.finished,
|
||||
job.cancel,
|
||||
job.error,
|
||||
)
|
||||
for name, progress, cancel in self.finished_jobs
|
||||
for job in self.finished_jobs
|
||||
]
|
||||
)
|
||||
return history
|
||||
|
|
|
@ -39,32 +39,36 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
|||
)
|
||||
exit(EXIT_REPLACED)
|
||||
|
||||
name, fn, args, kwargs = context.pending.get(timeout=1.0)
|
||||
logger.info("worker for %s got job: %s", context.device.device, name)
|
||||
job = context.pending.get(timeout=1.0)
|
||||
logger.info("worker for %s got job: %s", context.device.device, job.name)
|
||||
|
||||
context.job = name # TODO: hax
|
||||
context.clear_flags()
|
||||
logger.info("starting job: %s", name)
|
||||
fn(context, *args, **kwargs)
|
||||
logger.info("job succeeded: %s", name)
|
||||
context.job = job.name # TODO: hax
|
||||
logger.info("starting job: %s", job.name)
|
||||
context.set_progress(0)
|
||||
job.fn(context, *job.args, **job.kwargs)
|
||||
logger.info("job succeeded: %s", job.name)
|
||||
context.set_finished()
|
||||
except Empty:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
logger.info("worker got keyboard interrupt")
|
||||
context.set_failed()
|
||||
exit(EXIT_INTERRUPT)
|
||||
except ValueError as e:
|
||||
logger.info(
|
||||
"value error in worker, exiting: %s",
|
||||
format_exception(type(e), e, e.__traceback__),
|
||||
logger.exception(
|
||||
"value error in worker, exiting: %s"
|
||||
)
|
||||
context.set_failed()
|
||||
exit(EXIT_ERROR)
|
||||
except Exception as e:
|
||||
e_str = str(e)
|
||||
if "Failed to allocate memory" in e_str or "out of memory" in e_str:
|
||||
logger.error("detected out-of-memory error, exiting: %s", e)
|
||||
context.set_failed()
|
||||
exit(EXIT_MEMORY)
|
||||
else:
|
||||
logger.exception(
|
||||
"error while running job",
|
||||
)
|
||||
context.set_failed()
|
||||
# carry on
|
||||
|
|
|
@ -175,6 +175,8 @@ export interface ImageResponse {
|
|||
* Status response from the ready endpoint.
|
||||
*/
|
||||
export interface ReadyResponse {
|
||||
cancel: boolean;
|
||||
error: boolean;
|
||||
progress: number;
|
||||
ready: boolean;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue