fix(api): track currently active worker for each device
This commit is contained in:
parent
57fed94337
commit
c0a01efef4
|
@ -1,4 +1,5 @@
|
|||
from logging import getLogger
|
||||
from os import getpid
|
||||
from typing import Any, Callable, Tuple
|
||||
|
||||
from torch.multiprocessing import Queue, Value
|
||||
|
@ -15,6 +16,7 @@ class WorkerContext:
|
|||
cancel: "Value[bool]"
|
||||
job: str
|
||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
|
||||
current: "Value[int]"
|
||||
progress: "Queue[Tuple[str, str, int]]"
|
||||
|
||||
def __init__(
|
||||
|
@ -26,6 +28,7 @@ class WorkerContext:
|
|||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
|
||||
progress: "Queue[Tuple[str, str, int]]",
|
||||
finished: "Queue[Tuple[str, str]]",
|
||||
current: "Value[int]",
|
||||
):
|
||||
self.job = job
|
||||
self.device = device
|
||||
|
@ -34,10 +37,17 @@ class WorkerContext:
|
|||
self.finished = finished
|
||||
self.logs = logs
|
||||
self.pending = pending
|
||||
self.current = current
|
||||
|
||||
def is_cancelled(self) -> bool:
    """
    Report whether this context has been flagged as cancelled.

    Returns the value currently stored in the `cancel` flag.
    """
    cancelled = self.cancel.value
    return cancelled
|
||||
|
||||
def is_current(self) -> bool:
    """
    Check whether this process is still the active worker for its device.

    The `current` value stores the PID recorded for the device's worker.
    When it is positive, this process is current only if its own PID
    matches. When it is zero, no worker has been recorded and this
    process is treated as current.
    """
    # Read the shared value exactly once: it can be written from another
    # process, so the "is it set" test and the PID comparison must both be
    # made against the same observation.
    active_pid = self.current.value
    if active_pid > 0:
        return active_pid == getpid()

    return True
|
||||
|
||||
def get_device(self) -> DeviceParams:
|
||||
"""
|
||||
Get the device assigned to this job.
|
||||
|
|
|
@ -23,6 +23,7 @@ class DevicePoolExecutor:
|
|||
|
||||
leaking: List[Tuple[str, Process]]
|
||||
context: Dict[str, WorkerContext] # Device -> Context
|
||||
current: Dict[str, "Value[int]"]
|
||||
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
|
||||
threads: Dict[str, Thread]
|
||||
workers: Dict[str, Process]
|
||||
|
@ -52,6 +53,7 @@ class DevicePoolExecutor:
|
|||
|
||||
self.leaking = []
|
||||
self.context = {}
|
||||
self.current = {}
|
||||
self.pending = {}
|
||||
self.threads = {}
|
||||
self.workers = {}
|
||||
|
@ -85,6 +87,14 @@ class DevicePoolExecutor:
|
|||
pending = Queue(self.max_pending_per_worker)
|
||||
self.pending[name] = pending
|
||||
|
||||
if name in self.current:
|
||||
logger.debug("using existing current worker value")
|
||||
current = self.current[name]
|
||||
else:
|
||||
logger.debug("creating new current worker value")
|
||||
current = Value("L", 0)
|
||||
self.current[name] = current
|
||||
|
||||
context = WorkerContext(
|
||||
name,
|
||||
device,
|
||||
|
@ -93,16 +103,19 @@ class DevicePoolExecutor:
|
|||
finished=self.finished,
|
||||
logs=self.logs,
|
||||
pending=pending,
|
||||
current=current,
|
||||
)
|
||||
self.context[name] = context
|
||||
self.workers[name] = Process(
|
||||
worker = Process(
|
||||
name=f"onnx-web worker: {name}",
|
||||
target=worker_main,
|
||||
args=(context, self.server),
|
||||
)
|
||||
|
||||
logger.debug("starting worker for device %s", device)
|
||||
self.workers[name].start()
|
||||
worker.start()
|
||||
self.workers[name] = worker
|
||||
current.value = worker.pid
|
||||
|
||||
def create_logger_worker(self) -> None:
|
||||
def logger_worker(logs: Queue):
|
||||
|
|
|
@ -24,6 +24,10 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
|||
|
||||
while True:
|
||||
try:
|
||||
if not context.is_current():
|
||||
logger.warning("worker has been replaced, exiting")
|
||||
exit(3)
|
||||
|
||||
name, fn, args, kwargs = context.pending.get(timeout=1.0)
|
||||
logger.info("worker for %s got job: %s", context.device.device, name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue