1
0
Fork 0
onnx-web/api/onnx_web/worker/context.py

187 lines
5.0 KiB
Python

from logging import getLogger
from os import getpid
from typing import Any, Callable, Optional
from torch.multiprocessing import Queue, Value
from ..errors import CancelledException
from ..params import DeviceParams
from .command import JobCommand, ProgressCommand
logger = getLogger(__name__)
ProgressCallback = Callable[[int, int, Any], None]
class WorkerContext:
cancel: "Value[bool]"
job: Optional[str]
name: str
pending: "Queue[JobCommand]"
active_pid: "Value[int]"
progress: "Queue[ProgressCommand]"
last_progress: Optional[ProgressCommand]
idle: "Value[bool]"
timeout: float
retries: int
initial_retries: int
def __init__(
self,
name: str,
device: DeviceParams,
cancel: "Value[bool]",
logs: "Queue[str]",
pending: "Queue[JobCommand]",
progress: "Queue[ProgressCommand]",
active_pid: "Value[int]",
idle: "Value[bool]",
retries: int,
timeout: float,
):
self.job = None
self.name = name
self.device = device
self.cancel = cancel
self.progress = progress
self.logs = logs
self.pending = pending
self.active_pid = active_pid
self.last_progress = None
self.idle = idle
self.initial_retries = retries
self.retries = retries
self.timeout = timeout
def start(self, job: str) -> None:
self.job = job
self.retries = self.initial_retries
self.set_cancel(cancel=False)
self.set_idle(idle=False)
def is_active(self) -> bool:
return self.get_active() == getpid()
def is_cancelled(self) -> bool:
return self.cancel.value
def is_idle(self) -> bool:
return self.idle.value
def get_active(self) -> int:
with self.active_pid.get_lock():
return self.active_pid.value
def get_device(self) -> DeviceParams:
"""
Get the device assigned to this job.
"""
return self.device
def get_progress(self) -> int:
if self.last_progress is not None:
return self.last_progress.progress
return 0
def get_progress_callback(self) -> ProgressCallback:
from ..chain.base import ChainProgress
def on_progress(step: int, timestep: int, latents: Any):
on_progress.step = step
self.set_progress(step)
return ChainProgress.from_progress(on_progress)
def set_cancel(self, cancel: bool = True) -> None:
with self.cancel.get_lock():
self.cancel.value = cancel
def set_idle(self, idle: bool = True) -> None:
with self.idle.get_lock():
self.idle.value = idle
def set_progress(self, progress: int) -> None:
if self.job is None:
raise RuntimeError("no job on which to set progress")
if self.is_cancelled():
raise CancelledException("job has been cancelled")
logger.debug("setting progress for job %s to %s", self.job, progress)
self.last_progress = ProgressCommand(
self.job,
self.device.device,
False,
progress,
self.is_cancelled(),
False,
)
self.progress.put(
self.last_progress,
block=False,
)
def finish(self) -> None:
if self.job is None:
logger.warning("setting finished without an active job")
else:
logger.debug("setting finished for job %s", self.job)
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
False,
)
self.progress.put(
self.last_progress,
block=False,
)
def fail(self) -> None:
if self.job is None:
logger.warning("setting failure without an active job")
else:
logger.warning("setting failure for job %s", self.job)
try:
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
True,
)
self.progress.put(
self.last_progress,
block=False,
)
except Exception:
logger.exception("error setting failure on job %s", self.job)
class JobStatus:
name: str
device: str
progress: int
cancelled: bool
finished: bool
def __init__(
self,
name: str,
device: DeviceParams,
progress: int = 0,
cancelled: bool = False,
finished: bool = False,
) -> None:
self.name = name
self.device = device.device
self.progress = progress
self.cancelled = cancelled
self.finished = finished