feat(api): add error flag to image ready response
This commit is contained in:
parent
0ab52f0c24
commit
7cf5554bef
|
@ -50,9 +50,11 @@ from .utils import wrap_route
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def ready_reply(ready: bool, progress: int = 0):
|
def ready_reply(ready: bool, progress: int = 0, error: bool = False, cancel: bool = False):
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
|
"cancel": cancel,
|
||||||
|
"error": error,
|
||||||
"progress": progress,
|
"progress": progress,
|
||||||
"ready": ready,
|
"ready": ready,
|
||||||
}
|
}
|
||||||
|
@ -437,7 +439,7 @@ def cancel(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
output_file = sanitize_name(output_file)
|
output_file = sanitize_name(output_file)
|
||||||
cancel = pool.cancel(output_file)
|
cancel = pool.cancel(output_file)
|
||||||
|
|
||||||
return ready_reply(cancel)
|
return ready_reply(cancel == False, cancel=cancel)
|
||||||
|
|
||||||
|
|
||||||
def ready(context: ServerContext, pool: DevicePoolExecutor):
|
def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
@ -446,14 +448,14 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
return error_reply("output name is required")
|
return error_reply("output name is required")
|
||||||
|
|
||||||
output_file = sanitize_name(output_file)
|
output_file = sanitize_name(output_file)
|
||||||
done, progress = pool.done(output_file)
|
progress = pool.done(output_file)
|
||||||
|
|
||||||
if done is None:
|
if progress is None:
|
||||||
output = base_join(context.output_path, output_file)
|
output = base_join(context.output_path, output_file)
|
||||||
if path.exists(output):
|
if path.exists(output):
|
||||||
return ready_reply(True)
|
return ready_reply(True)
|
||||||
|
|
||||||
return ready_reply(done or False, progress=progress)
|
return ready_reply(progress.finished, progress=progress.progress, error=progress.error, cancel=progress.cancel)
|
||||||
|
|
||||||
|
|
||||||
def status(context: ServerContext, pool: DevicePoolExecutor):
|
def status(context: ServerContext, pool: DevicePoolExecutor):
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
from typing import Callable, Any
|
||||||
|
|
||||||
|
class ProgressCommand():
|
||||||
|
device: str
|
||||||
|
job: str
|
||||||
|
finished: bool
|
||||||
|
progress: int
|
||||||
|
cancel: bool
|
||||||
|
error: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
job: str,
|
||||||
|
device: str,
|
||||||
|
finished: bool,
|
||||||
|
progress: int,
|
||||||
|
cancel: bool = False,
|
||||||
|
error: bool = False,
|
||||||
|
):
|
||||||
|
self.job = job
|
||||||
|
self.device = device
|
||||||
|
self.finished = finished
|
||||||
|
self.progress = progress
|
||||||
|
self.cancel = cancel
|
||||||
|
self.error = error
|
||||||
|
|
||||||
|
class JobCommand():
|
||||||
|
name: str
|
||||||
|
fn: Callable[..., None]
|
||||||
|
args: Any
|
||||||
|
kwargs: dict[str, Any]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
fn: Callable[..., None],
|
||||||
|
args: Any,
|
||||||
|
kwargs: dict[str, Any],
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.fn = fn
|
||||||
|
self.args = args
|
||||||
|
self.kwargs = kwargs
|
|
@ -4,6 +4,7 @@ from typing import Any, Callable, Tuple
|
||||||
|
|
||||||
from torch.multiprocessing import Queue, Value
|
from torch.multiprocessing import Queue, Value
|
||||||
|
|
||||||
|
from .command import JobCommand, ProgressCommand
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -15,9 +16,9 @@ ProgressCallback = Callable[[int, int, Any], None]
|
||||||
class WorkerContext:
|
class WorkerContext:
|
||||||
cancel: "Value[bool]"
|
cancel: "Value[bool]"
|
||||||
job: str
|
job: str
|
||||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
|
pending: "Queue[JobCommand]"
|
||||||
current: "Value[int]"
|
current: "Value[int]"
|
||||||
progress: "Queue[Tuple[str, str, int]]"
|
progress: "Queue[ProgressCommand]"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -25,16 +26,14 @@ class WorkerContext:
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
cancel: "Value[bool]",
|
cancel: "Value[bool]",
|
||||||
logs: "Queue[str]",
|
logs: "Queue[str]",
|
||||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
|
pending: "Queue[JobCommand]",
|
||||||
progress: "Queue[Tuple[str, str, int]]",
|
progress: "Queue[ProgressCommand]",
|
||||||
finished: "Queue[Tuple[str, str]]",
|
|
||||||
current: "Value[int]",
|
current: "Value[int]",
|
||||||
):
|
):
|
||||||
self.job = job
|
self.job = job
|
||||||
self.device = device
|
self.device = device
|
||||||
self.cancel = cancel
|
self.cancel = cancel
|
||||||
self.progress = progress
|
self.progress = progress
|
||||||
self.finished = finished
|
|
||||||
self.logs = logs
|
self.logs = logs
|
||||||
self.pending = pending
|
self.pending = pending
|
||||||
self.current = current
|
self.current = current
|
||||||
|
@ -61,10 +60,6 @@ class WorkerContext:
|
||||||
def get_progress_callback(self) -> ProgressCallback:
|
def get_progress_callback(self) -> ProgressCallback:
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
on_progress.step = step
|
on_progress.step = step
|
||||||
if self.is_cancelled():
|
|
||||||
raise RuntimeError("job has been cancelled")
|
|
||||||
else:
|
|
||||||
logger.debug("setting progress for job %s to %s", self.job, step)
|
|
||||||
self.set_progress(step)
|
self.set_progress(step)
|
||||||
|
|
||||||
return on_progress
|
return on_progress
|
||||||
|
@ -74,14 +69,22 @@ class WorkerContext:
|
||||||
self.cancel.value = cancel
|
self.cancel.value = cancel
|
||||||
|
|
||||||
def set_progress(self, progress: int) -> None:
|
def set_progress(self, progress: int) -> None:
|
||||||
self.progress.put((self.job, self.device.device, progress), block=False)
|
if self.is_cancelled():
|
||||||
|
raise RuntimeError("job has been cancelled")
|
||||||
|
else:
|
||||||
|
logger.debug("setting progress for job %s to %s", self.job, progress)
|
||||||
|
self.progress.put(ProgressCommand(self.job, self.device.device, False, progress, self.is_cancelled(), False), block=False)
|
||||||
|
|
||||||
def set_finished(self) -> None:
|
def set_finished(self) -> None:
|
||||||
self.finished.put((self.job, self.device.device), block=False)
|
logger.debug("setting finished for job %s", self.job)
|
||||||
|
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), False), block=False)
|
||||||
|
|
||||||
def clear_flags(self) -> None:
|
def set_failed(self) -> None:
|
||||||
self.set_cancel(False)
|
logger.warning("setting failure for job %s", self.job)
|
||||||
self.set_progress(0)
|
try:
|
||||||
|
self.progress.put(ProgressCommand(self.job, self.device.device, True, self.get_progress(), self.is_cancelled(), True), block=False)
|
||||||
|
except:
|
||||||
|
logger.exception("error setting failure on job %s", self.job)
|
||||||
|
|
||||||
|
|
||||||
class JobStatus:
|
class JobStatus:
|
||||||
|
|
|
@ -8,6 +8,7 @@ from torch.multiprocessing import Process, Queue, Value
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
|
from .command import JobCommand, ProgressCommand
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
from .worker import worker_main
|
from .worker import worker_main
|
||||||
|
|
||||||
|
@ -24,18 +25,17 @@ class DevicePoolExecutor:
|
||||||
leaking: List[Tuple[str, Process]]
|
leaking: List[Tuple[str, Process]]
|
||||||
context: Dict[str, WorkerContext] # Device -> Context
|
context: Dict[str, WorkerContext] # Device -> Context
|
||||||
current: Dict[str, "Value[int]"]
|
current: Dict[str, "Value[int]"]
|
||||||
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
|
pending: Dict[str, "Queue[JobCommand]"]
|
||||||
threads: Dict[str, Thread]
|
threads: Dict[str, Thread]
|
||||||
workers: Dict[str, Process]
|
workers: Dict[str, Process]
|
||||||
|
|
||||||
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
|
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
|
||||||
cancelled_jobs: List[str]
|
cancelled_jobs: List[str]
|
||||||
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
|
finished_jobs: List[ProgressCommand]
|
||||||
total_jobs: Dict[str, int] # Device -> job count
|
total_jobs: Dict[str, int] # Device -> job count
|
||||||
|
|
||||||
logs: "Queue"
|
logs: "Queue[str]"
|
||||||
progress: "Queue[Tuple[str, str, int]]"
|
progress: "Queue[ProgressCommand]"
|
||||||
finished: "Queue[Tuple[str, str]]"
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -142,18 +142,27 @@ class DevicePoolExecutor:
|
||||||
logger_thread.start()
|
logger_thread.start()
|
||||||
|
|
||||||
def create_progress_worker(self) -> None:
|
def create_progress_worker(self) -> None:
|
||||||
def progress_worker(progress: Queue):
|
def update_job(progress: ProgressCommand):
|
||||||
|
if progress.finished:
|
||||||
|
logger.info("job has finished: %s", progress.job)
|
||||||
|
self.finished_jobs.append(progress)
|
||||||
|
del self.active_jobs[progress.job]
|
||||||
|
self.join_leaking()
|
||||||
|
else:
|
||||||
|
logger.debug("progress update for job: %s to %s", progress.job, progress.progress)
|
||||||
|
self.active_jobs[progress.job] = progress
|
||||||
|
if progress.job in self.cancelled_jobs:
|
||||||
|
logger.debug(
|
||||||
|
"setting flag for cancelled job: %s on %s", progress.job, progress.device
|
||||||
|
)
|
||||||
|
self.context[progress.device].set_cancel()
|
||||||
|
|
||||||
|
def progress_worker(queue: "Queue[ProgressCommand]"):
|
||||||
logger.trace("checking in from progress worker thread")
|
logger.trace("checking in from progress worker thread")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
job, device, value = progress.get(timeout=(self.join_timeout / 2))
|
progress = queue.get(timeout=(self.join_timeout / 2))
|
||||||
logger.debug("progress update for job: %s to %s", job, value)
|
update_job(progress)
|
||||||
self.active_jobs[job] = (device, value)
|
|
||||||
if job in self.cancelled_jobs:
|
|
||||||
logger.debug(
|
|
||||||
"setting flag for cancelled job: %s on %s", job, device
|
|
||||||
)
|
|
||||||
self.context[device].set_cancel()
|
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -178,12 +187,7 @@ class DevicePoolExecutor:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
job, device = finished.get(timeout=(self.join_timeout / 2))
|
job, device = finished.get(timeout=(self.join_timeout / 2))
|
||||||
logger.info("job has been finished: %s", job)
|
|
||||||
context = self.context[device]
|
|
||||||
_device, progress = self.active_jobs[job]
|
|
||||||
self.finished_jobs.append((job, progress, context.cancel.value))
|
|
||||||
del self.active_jobs[job]
|
|
||||||
self.join_leaking()
|
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -232,37 +236,36 @@ class DevicePoolExecutor:
|
||||||
should be cancelled on the next progress callback.
|
should be cancelled on the next progress callback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.cancelled_jobs.append(key)
|
for job in self.finished_jobs:
|
||||||
|
if job.job == key:
|
||||||
|
logger.debug("cannot cancel finished job: %s", key)
|
||||||
|
return False
|
||||||
|
|
||||||
if key not in self.active_jobs:
|
if key not in self.active_jobs:
|
||||||
logger.debug("cancelled job has not been started yet: %s", key)
|
logger.debug("cancelled job is not active: %s", key)
|
||||||
|
else:
|
||||||
|
job = self.active_jobs[key]
|
||||||
|
logger.info("cancelling job %s, active on device %s", key, job.device)
|
||||||
|
|
||||||
|
self.cancelled_jobs.append(key)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
device, _progress = self.active_jobs[key]
|
def done(self, key: str) -> Optional[ProgressCommand]:
|
||||||
logger.info("cancelling job %s, active on device %s", key, device)
|
|
||||||
|
|
||||||
context = self.context[device]
|
|
||||||
context.set_cancel()
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
|
||||||
"""
|
"""
|
||||||
Check if a job has been finished and report the last progress update.
|
Check if a job has been finished and report the last progress update.
|
||||||
|
|
||||||
If the job is still active or pending, the first item will be False.
|
If the job is still active or pending, the first item will be False.
|
||||||
If the job is not finished or active, the first item will be None.
|
If the job is not finished or active, the first item will be None.
|
||||||
"""
|
"""
|
||||||
for k, p, c in self.finished_jobs:
|
for job in self.finished_jobs:
|
||||||
if k == key:
|
if job.job == key:
|
||||||
return (True, p)
|
return job
|
||||||
|
|
||||||
if key not in self.active_jobs:
|
if key not in self.active_jobs:
|
||||||
logger.debug("checking status for unknown job: %s", key)
|
logger.debug("checking status for unknown job: %s", key)
|
||||||
return (None, 0)
|
return None
|
||||||
|
|
||||||
_device, progress = self.active_jobs[key]
|
return self.active_jobs[key]
|
||||||
return (False, progress)
|
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
logger.info("stopping worker pool")
|
logger.info("stopping worker pool")
|
||||||
|
@ -387,22 +390,29 @@ class DevicePoolExecutor:
|
||||||
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
||||||
self.recycle()
|
self.recycle()
|
||||||
|
|
||||||
self.pending[device].put((key, fn, args, kwargs), block=False)
|
self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, bool]]:
|
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
|
||||||
history = [
|
history = [
|
||||||
(name, progress, False, name in self.cancelled_jobs)
|
(
|
||||||
for name, (_device, progress) in self.active_jobs.items()
|
name,
|
||||||
|
job.progress,
|
||||||
|
job.finished,
|
||||||
|
job.cancel,
|
||||||
|
job.error,
|
||||||
|
)
|
||||||
|
for name, job in self.active_jobs.items()
|
||||||
]
|
]
|
||||||
history.extend(
|
history.extend(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
name,
|
job.job,
|
||||||
progress,
|
job.progress,
|
||||||
True,
|
job.finished,
|
||||||
cancel,
|
job.cancel,
|
||||||
|
job.error,
|
||||||
)
|
)
|
||||||
for name, progress, cancel in self.finished_jobs
|
for job in self.finished_jobs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
return history
|
return history
|
||||||
|
|
|
@ -39,32 +39,36 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
||||||
)
|
)
|
||||||
exit(EXIT_REPLACED)
|
exit(EXIT_REPLACED)
|
||||||
|
|
||||||
name, fn, args, kwargs = context.pending.get(timeout=1.0)
|
job = context.pending.get(timeout=1.0)
|
||||||
logger.info("worker for %s got job: %s", context.device.device, name)
|
logger.info("worker for %s got job: %s", context.device.device, job.name)
|
||||||
|
|
||||||
context.job = name # TODO: hax
|
context.job = job.name # TODO: hax
|
||||||
context.clear_flags()
|
logger.info("starting job: %s", job.name)
|
||||||
logger.info("starting job: %s", name)
|
context.set_progress(0)
|
||||||
fn(context, *args, **kwargs)
|
job.fn(context, *job.args, **job.kwargs)
|
||||||
logger.info("job succeeded: %s", name)
|
logger.info("job succeeded: %s", job.name)
|
||||||
context.set_finished()
|
context.set_finished()
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("worker got keyboard interrupt")
|
logger.info("worker got keyboard interrupt")
|
||||||
|
context.set_failed()
|
||||||
exit(EXIT_INTERRUPT)
|
exit(EXIT_INTERRUPT)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.info(
|
logger.exception(
|
||||||
"value error in worker, exiting: %s",
|
"value error in worker, exiting: %s"
|
||||||
format_exception(type(e), e, e.__traceback__),
|
|
||||||
)
|
)
|
||||||
|
context.set_failed()
|
||||||
exit(EXIT_ERROR)
|
exit(EXIT_ERROR)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
e_str = str(e)
|
e_str = str(e)
|
||||||
if "Failed to allocate memory" in e_str or "out of memory" in e_str:
|
if "Failed to allocate memory" in e_str or "out of memory" in e_str:
|
||||||
logger.error("detected out-of-memory error, exiting: %s", e)
|
logger.error("detected out-of-memory error, exiting: %s", e)
|
||||||
|
context.set_failed()
|
||||||
exit(EXIT_MEMORY)
|
exit(EXIT_MEMORY)
|
||||||
else:
|
else:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"error while running job",
|
"error while running job",
|
||||||
)
|
)
|
||||||
|
context.set_failed()
|
||||||
|
# carry on
|
||||||
|
|
|
@ -175,6 +175,8 @@ export interface ImageResponse {
|
||||||
* Status response from the ready endpoint.
|
* Status response from the ready endpoint.
|
||||||
*/
|
*/
|
||||||
export interface ReadyResponse {
|
export interface ReadyResponse {
|
||||||
|
cancel: boolean;
|
||||||
|
error: boolean;
|
||||||
progress: number;
|
progress: number;
|
||||||
ready: boolean;
|
ready: boolean;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue