1
0
Fork 0

fix(api): track last progress within worker

This commit is contained in:
Sean Sube 2023-03-18 15:32:49 -05:00
parent 5106dd48a9
commit 588c8c7fdb
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 48 additions and 38 deletions

View File

@@ -457,7 +457,9 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
if path.exists(output): if path.exists(output):
return ready_reply(True) return ready_reply(True)
else: else:
return ready_reply(True, error=True) # is a missing image really an error? yes will display the retry button return ready_reply(
True, error=True
) # is a missing image really an error? yes will display the retry button
return ready_reply( return ready_reply(
progress.finished, progress.finished,

View File

@@ -1,6 +1,6 @@
from logging import getLogger from logging import getLogger
from os import getpid from os import getpid
from typing import Any, Callable from typing import Any, Callable, Optional
from torch.multiprocessing import Queue, Value from torch.multiprocessing import Queue, Value
@@ -17,8 +17,9 @@ class WorkerContext:
cancel: "Value[bool]" cancel: "Value[bool]"
job: str job: str
pending: "Queue[JobCommand]" pending: "Queue[JobCommand]"
current: "Value[int]" active_pid: "Value[int]"
progress: "Queue[ProgressCommand]" progress: "Queue[ProgressCommand]"
last_progress: Optional[ProgressCommand]
def __init__( def __init__(
self, self,
@@ -28,7 +29,7 @@ class WorkerContext:
logs: "Queue[str]", logs: "Queue[str]",
pending: "Queue[JobCommand]", pending: "Queue[JobCommand]",
progress: "Queue[ProgressCommand]", progress: "Queue[ProgressCommand]",
current: "Value[int]", active_pid: "Value[int]",
): ):
self.job = job self.job = job
self.device = device self.device = device
@@ -36,17 +37,17 @@ class WorkerContext:
self.progress = progress self.progress = progress
self.logs = logs self.logs = logs
self.pending = pending self.pending = pending
self.current = current self.active_pid = active_pid
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
return self.cancel.value return self.cancel.value
def is_current(self) -> bool: def is_active(self) -> bool:
return self.get_current() == getpid() return self.get_active() == getpid()
def get_current(self) -> int: def get_active(self) -> int:
with self.current.get_lock(): with self.active_pid.get_lock():
return self.current.value return self.active_pid.value
def get_device(self) -> DeviceParams: def get_device(self) -> DeviceParams:
""" """
@@ -55,7 +56,10 @@ class WorkerContext:
return self.device return self.device
def get_progress(self) -> int: def get_progress(self) -> int:
return self.progress.value if self.last_progress is not None:
return self.last_progress.progress
return 0
def get_progress_callback(self) -> ProgressCallback: def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any): def on_progress(step: int, timestep: int, latents: Any):
@@ -73,44 +77,48 @@ class WorkerContext:
raise RuntimeError("job has been cancelled") raise RuntimeError("job has been cancelled")
else: else:
logger.debug("setting progress for job %s to %s", self.job, progress) logger.debug("setting progress for job %s to %s", self.job, progress)
self.last_progress = ProgressCommand(
self.job,
self.device.device,
False,
progress,
self.is_cancelled(),
False,
)
self.progress.put( self.progress.put(
ProgressCommand( self.last_progress,
self.job,
self.device.device,
False,
progress,
self.is_cancelled(),
False,
),
block=False, block=False,
) )
def set_finished(self) -> None: def set_finished(self) -> None:
logger.debug("setting finished for job %s", self.job) logger.debug("setting finished for job %s", self.job)
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
False,
)
self.progress.put( self.progress.put(
ProgressCommand( self.last_progress,
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
False,
),
block=False, block=False,
) )
def set_failed(self) -> None: def set_failed(self) -> None:
logger.warning("setting failure for job %s", self.job) logger.warning("setting failure for job %s", self.job)
try: try:
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
True,
)
self.progress.put( self.progress.put(
ProgressCommand( self.last_progress,
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
True,
),
block=False, block=False,
) )
except Exception: except Exception:

View File

@@ -99,7 +99,7 @@ class DevicePoolExecutor:
progress=self.progress, progress=self.progress,
logs=self.logs, logs=self.logs,
pending=pending, pending=pending,
current=current, active_pid=current,
) )
self.context[name] = context self.context[name] = context
worker = Process( worker = Process(

View File

@@ -29,11 +29,11 @@ def worker_main(context: WorkerContext, server: ServerContext):
while True: while True:
try: try:
if not context.is_current(): if not context.is_active():
logger.warning( logger.warning(
"worker %s has been replaced by %s, exiting", "worker %s has been replaced by %s, exiting",
getpid(), getpid(),
context.get_current(), context.get_active(),
) )
exit(EXIT_REPLACED) exit(EXIT_REPLACED)