1
0
Fork 0

fix(api): track last progress within worker

This commit is contained in:
Sean Sube 2023-03-18 15:32:49 -05:00
parent 5106dd48a9
commit 588c8c7fdb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 48 additions and 38 deletions

View File

@ -457,7 +457,9 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
if path.exists(output):
return ready_reply(True)
else:
return ready_reply(True, error=True) # is a missing image really an error? yes will display the retry button
return ready_reply(
True, error=True
) # is a missing image really an error? yes, this will display the retry button
return ready_reply(
progress.finished,

View File

@ -1,6 +1,6 @@
from logging import getLogger
from os import getpid
from typing import Any, Callable
from typing import Any, Callable, Optional
from torch.multiprocessing import Queue, Value
@ -17,8 +17,9 @@ class WorkerContext:
cancel: "Value[bool]"
job: str
pending: "Queue[JobCommand]"
current: "Value[int]"
active_pid: "Value[int]"
progress: "Queue[ProgressCommand]"
last_progress: Optional[ProgressCommand]
def __init__(
self,
@ -28,7 +29,7 @@ class WorkerContext:
logs: "Queue[str]",
pending: "Queue[JobCommand]",
progress: "Queue[ProgressCommand]",
current: "Value[int]",
active_pid: "Value[int]",
):
self.job = job
self.device = device
@ -36,17 +37,17 @@ class WorkerContext:
self.progress = progress
self.logs = logs
self.pending = pending
self.current = current
self.active_pid = active_pid
def is_cancelled(self) -> bool:
return self.cancel.value
def is_current(self) -> bool:
return self.get_current() == getpid()
def is_active(self) -> bool:
return self.get_active() == getpid()
def get_current(self) -> int:
with self.current.get_lock():
return self.current.value
def get_active(self) -> int:
with self.active_pid.get_lock():
return self.active_pid.value
def get_device(self) -> DeviceParams:
"""
@ -55,7 +56,10 @@ class WorkerContext:
return self.device
def get_progress(self) -> int:
return self.progress.value
if self.last_progress is not None:
return self.last_progress.progress
return 0
def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any):
@ -73,44 +77,48 @@ class WorkerContext:
raise RuntimeError("job has been cancelled")
else:
logger.debug("setting progress for job %s to %s", self.job, progress)
self.progress.put(
ProgressCommand(
self.last_progress = ProgressCommand(
self.job,
self.device.device,
False,
progress,
self.is_cancelled(),
False,
),
)
self.progress.put(
self.last_progress,
block=False,
)
def set_finished(self) -> None:
logger.debug("setting finished for job %s", self.job)
self.progress.put(
ProgressCommand(
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
False,
),
)
self.progress.put(
self.last_progress,
block=False,
)
def set_failed(self) -> None:
logger.warning("setting failure for job %s", self.job)
try:
self.progress.put(
ProgressCommand(
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
True,
),
)
self.progress.put(
self.last_progress,
block=False,
)
except Exception:

View File

@ -99,7 +99,7 @@ class DevicePoolExecutor:
progress=self.progress,
logs=self.logs,
pending=pending,
current=current,
active_pid=current,
)
self.context[name] = context
worker = Process(

View File

@ -29,11 +29,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
while True:
try:
if not context.is_current():
if not context.is_active():
logger.warning(
"worker %s has been replaced by %s, exiting",
getpid(),
context.get_current(),
context.get_active(),
)
exit(EXIT_REPLACED)