fix(api): track last progress within worker
This commit is contained in:
parent
5106dd48a9
commit
588c8c7fdb
|
@ -457,7 +457,9 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
|
|||
if path.exists(output):
|
||||
return ready_reply(True)
|
||||
else:
|
||||
return ready_reply(True, error=True) # is a missing image really an error? yes will display the retry button
|
||||
return ready_reply(
|
||||
True, error=True
|
||||
) # is a missing image really an error? yes will display the retry button
|
||||
|
||||
return ready_reply(
|
||||
progress.finished,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import getpid
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from torch.multiprocessing import Queue, Value
|
||||
|
||||
|
@ -17,8 +17,9 @@ class WorkerContext:
|
|||
cancel: "Value[bool]"
|
||||
job: str
|
||||
pending: "Queue[JobCommand]"
|
||||
current: "Value[int]"
|
||||
active_pid: "Value[int]"
|
||||
progress: "Queue[ProgressCommand]"
|
||||
last_progress: Optional[ProgressCommand]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -28,7 +29,7 @@ class WorkerContext:
|
|||
logs: "Queue[str]",
|
||||
pending: "Queue[JobCommand]",
|
||||
progress: "Queue[ProgressCommand]",
|
||||
current: "Value[int]",
|
||||
active_pid: "Value[int]",
|
||||
):
|
||||
self.job = job
|
||||
self.device = device
|
||||
|
@ -36,17 +37,17 @@ class WorkerContext:
|
|||
self.progress = progress
|
||||
self.logs = logs
|
||||
self.pending = pending
|
||||
self.current = current
|
||||
self.active_pid = active_pid
|
||||
|
||||
def is_cancelled(self) -> bool:
|
||||
return self.cancel.value
|
||||
|
||||
def is_current(self) -> bool:
|
||||
return self.get_current() == getpid()
|
||||
def is_active(self) -> bool:
|
||||
return self.get_active() == getpid()
|
||||
|
||||
def get_current(self) -> int:
|
||||
with self.current.get_lock():
|
||||
return self.current.value
|
||||
def get_active(self) -> int:
|
||||
with self.active_pid.get_lock():
|
||||
return self.active_pid.value
|
||||
|
||||
def get_device(self) -> DeviceParams:
|
||||
"""
|
||||
|
@ -55,7 +56,10 @@ class WorkerContext:
|
|||
return self.device
|
||||
|
||||
def get_progress(self) -> int:
|
||||
return self.progress.value
|
||||
if self.last_progress is not None:
|
||||
return self.last_progress.progress
|
||||
|
||||
return 0
|
||||
|
||||
def get_progress_callback(self) -> ProgressCallback:
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
|
@ -73,44 +77,48 @@ class WorkerContext:
|
|||
raise RuntimeError("job has been cancelled")
|
||||
else:
|
||||
logger.debug("setting progress for job %s to %s", self.job, progress)
|
||||
self.last_progress = ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
False,
|
||||
progress,
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
)
|
||||
|
||||
self.progress.put(
|
||||
ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
False,
|
||||
progress,
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
),
|
||||
self.last_progress,
|
||||
block=False,
|
||||
)
|
||||
|
||||
def set_finished(self) -> None:
|
||||
logger.debug("setting finished for job %s", self.job)
|
||||
self.last_progress = ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
)
|
||||
self.progress.put(
|
||||
ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
False,
|
||||
),
|
||||
self.last_progress,
|
||||
block=False,
|
||||
)
|
||||
|
||||
def set_failed(self) -> None:
|
||||
logger.warning("setting failure for job %s", self.job)
|
||||
try:
|
||||
self.last_progress = ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
True,
|
||||
)
|
||||
self.progress.put(
|
||||
ProgressCommand(
|
||||
self.job,
|
||||
self.device.device,
|
||||
True,
|
||||
self.get_progress(),
|
||||
self.is_cancelled(),
|
||||
True,
|
||||
),
|
||||
self.last_progress,
|
||||
block=False,
|
||||
)
|
||||
except Exception:
|
||||
|
|
|
@ -99,7 +99,7 @@ class DevicePoolExecutor:
|
|||
progress=self.progress,
|
||||
logs=self.logs,
|
||||
pending=pending,
|
||||
current=current,
|
||||
active_pid=current,
|
||||
)
|
||||
self.context[name] = context
|
||||
worker = Process(
|
||||
|
|
|
@ -29,11 +29,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
|||
|
||||
while True:
|
||||
try:
|
||||
if not context.is_current():
|
||||
if not context.is_active():
|
||||
logger.warning(
|
||||
"worker %s has been replaced by %s, exiting",
|
||||
getpid(),
|
||||
context.get_current(),
|
||||
context.get_active(),
|
||||
)
|
||||
exit(EXIT_REPLACED)
|
||||
|
||||
|
|
Loading…
Reference in New Issue