fix(api): track currently active worker for each device
This commit is contained in:
parent
57fed94337
commit
c0a01efef4
|
@ -1,4 +1,5 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
|
from os import getpid
|
||||||
from typing import Any, Callable, Tuple
|
from typing import Any, Callable, Tuple
|
||||||
|
|
||||||
from torch.multiprocessing import Queue, Value
|
from torch.multiprocessing import Queue, Value
|
||||||
|
@ -15,6 +16,7 @@ class WorkerContext:
|
||||||
cancel: "Value[bool]"
|
cancel: "Value[bool]"
|
||||||
job: str
|
job: str
|
||||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
|
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
|
||||||
|
current: "Value[int]"
|
||||||
progress: "Queue[Tuple[str, str, int]]"
|
progress: "Queue[Tuple[str, str, int]]"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -26,6 +28,7 @@ class WorkerContext:
|
||||||
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
|
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
|
||||||
progress: "Queue[Tuple[str, str, int]]",
|
progress: "Queue[Tuple[str, str, int]]",
|
||||||
finished: "Queue[Tuple[str, str]]",
|
finished: "Queue[Tuple[str, str]]",
|
||||||
|
current: "Value[int]",
|
||||||
):
|
):
|
||||||
self.job = job
|
self.job = job
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -34,10 +37,17 @@ class WorkerContext:
|
||||||
self.finished = finished
|
self.finished = finished
|
||||||
self.logs = logs
|
self.logs = logs
|
||||||
self.pending = pending
|
self.pending = pending
|
||||||
|
self.current = current
|
||||||
|
|
||||||
def is_cancelled(self) -> bool:
|
def is_cancelled(self) -> bool:
|
||||||
return self.cancel.value
|
return self.cancel.value
|
||||||
|
|
||||||
|
def is_current(self) -> bool:
|
||||||
|
if self.current.value > 0:
|
||||||
|
return self.current.value == getpid()
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def get_device(self) -> DeviceParams:
|
def get_device(self) -> DeviceParams:
|
||||||
"""
|
"""
|
||||||
Get the device assigned to this job.
|
Get the device assigned to this job.
|
||||||
|
|
|
@ -23,6 +23,7 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
leaking: List[Tuple[str, Process]]
|
leaking: List[Tuple[str, Process]]
|
||||||
context: Dict[str, WorkerContext] # Device -> Context
|
context: Dict[str, WorkerContext] # Device -> Context
|
||||||
|
current: Dict[str, "Value[int]"]
|
||||||
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
|
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
|
||||||
threads: Dict[str, Thread]
|
threads: Dict[str, Thread]
|
||||||
workers: Dict[str, Process]
|
workers: Dict[str, Process]
|
||||||
|
@ -52,6 +53,7 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
self.leaking = []
|
self.leaking = []
|
||||||
self.context = {}
|
self.context = {}
|
||||||
|
self.current = {}
|
||||||
self.pending = {}
|
self.pending = {}
|
||||||
self.threads = {}
|
self.threads = {}
|
||||||
self.workers = {}
|
self.workers = {}
|
||||||
|
@ -85,6 +87,14 @@ class DevicePoolExecutor:
|
||||||
pending = Queue(self.max_pending_per_worker)
|
pending = Queue(self.max_pending_per_worker)
|
||||||
self.pending[name] = pending
|
self.pending[name] = pending
|
||||||
|
|
||||||
|
if name in self.current:
|
||||||
|
logger.debug("using existing current worker value")
|
||||||
|
current = self.current[name]
|
||||||
|
else:
|
||||||
|
logger.debug("creating new current worker value")
|
||||||
|
current = Value("L", 0)
|
||||||
|
self.current[name] = current
|
||||||
|
|
||||||
context = WorkerContext(
|
context = WorkerContext(
|
||||||
name,
|
name,
|
||||||
device,
|
device,
|
||||||
|
@ -93,16 +103,19 @@ class DevicePoolExecutor:
|
||||||
finished=self.finished,
|
finished=self.finished,
|
||||||
logs=self.logs,
|
logs=self.logs,
|
||||||
pending=pending,
|
pending=pending,
|
||||||
|
current=current,
|
||||||
)
|
)
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
self.workers[name] = Process(
|
worker = Process(
|
||||||
name=f"onnx-web worker: {name}",
|
name=f"onnx-web worker: {name}",
|
||||||
target=worker_main,
|
target=worker_main,
|
||||||
args=(context, self.server),
|
args=(context, self.server),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("starting worker for device %s", device)
|
logger.debug("starting worker for device %s", device)
|
||||||
self.workers[name].start()
|
worker.start()
|
||||||
|
self.workers[name] = worker
|
||||||
|
current.value = worker.pid
|
||||||
|
|
||||||
def create_logger_worker(self) -> None:
|
def create_logger_worker(self) -> None:
|
||||||
def logger_worker(logs: Queue):
|
def logger_worker(logs: Queue):
|
||||||
|
|
|
@ -24,6 +24,10 @@ def worker_main(context: WorkerContext, server: ServerContext):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
if not context.is_current():
|
||||||
|
logger.warning("worker has been replaced, exiting")
|
||||||
|
exit(3)
|
||||||
|
|
||||||
name, fn, args, kwargs = context.pending.get(timeout=1.0)
|
name, fn, args, kwargs = context.pending.get(timeout=1.0)
|
||||||
logger.info("worker for %s got job: %s", context.device.device, name)
|
logger.info("worker for %s got job: %s", context.device.device, name)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue