1
0
Fork 0

fix(api): track currently active worker for each device

This commit is contained in:
Sean Sube 2023-03-05 21:28:21 -06:00
parent 57fed94337
commit c0a01efef4
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 29 additions and 2 deletions

View File

@ -1,4 +1,5 @@
from logging import getLogger from logging import getLogger
from os import getpid
from typing import Any, Callable, Tuple from typing import Any, Callable, Tuple
from torch.multiprocessing import Queue, Value from torch.multiprocessing import Queue, Value
@ -15,6 +16,7 @@ class WorkerContext:
cancel: "Value[bool]" cancel: "Value[bool]"
job: str job: str
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]" pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
current: "Value[int]"
progress: "Queue[Tuple[str, str, int]]" progress: "Queue[Tuple[str, str, int]]"
def __init__( def __init__(
@ -26,6 +28,7 @@ class WorkerContext:
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]", pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
progress: "Queue[Tuple[str, str, int]]", progress: "Queue[Tuple[str, str, int]]",
finished: "Queue[Tuple[str, str]]", finished: "Queue[Tuple[str, str]]",
current: "Value[int]",
): ):
self.job = job self.job = job
self.device = device self.device = device
@ -34,10 +37,17 @@ class WorkerContext:
self.finished = finished self.finished = finished
self.logs = logs self.logs = logs
self.pending = pending self.pending = pending
self.current = current
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
return self.cancel.value return self.cancel.value
def is_current(self) -> bool:
if self.current.value > 0:
return self.current.value == getpid()
return True
def get_device(self) -> DeviceParams: def get_device(self) -> DeviceParams:
""" """
Get the device assigned to this job. Get the device assigned to this job.

View File

@ -23,6 +23,7 @@ class DevicePoolExecutor:
leaking: List[Tuple[str, Process]] leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"]
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"] pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
threads: Dict[str, Thread] threads: Dict[str, Thread]
workers: Dict[str, Process] workers: Dict[str, Process]
@ -52,6 +53,7 @@ class DevicePoolExecutor:
self.leaking = [] self.leaking = []
self.context = {} self.context = {}
self.current = {}
self.pending = {} self.pending = {}
self.threads = {} self.threads = {}
self.workers = {} self.workers = {}
@ -85,6 +87,14 @@ class DevicePoolExecutor:
pending = Queue(self.max_pending_per_worker) pending = Queue(self.max_pending_per_worker)
self.pending[name] = pending self.pending[name] = pending
if name in self.current:
logger.debug("using existing current worker value")
current = self.current[name]
else:
logger.debug("creating new current worker value")
current = Value("L", 0)
self.current[name] = current
context = WorkerContext( context = WorkerContext(
name, name,
device, device,
@ -93,16 +103,19 @@ class DevicePoolExecutor:
finished=self.finished, finished=self.finished,
logs=self.logs, logs=self.logs,
pending=pending, pending=pending,
current=current,
) )
self.context[name] = context self.context[name] = context
self.workers[name] = Process( worker = Process(
name=f"onnx-web worker: {name}", name=f"onnx-web worker: {name}",
target=worker_main, target=worker_main,
args=(context, self.server), args=(context, self.server),
) )
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)
self.workers[name].start() worker.start()
self.workers[name] = worker
current.value = worker.pid
def create_logger_worker(self) -> None: def create_logger_worker(self) -> None:
def logger_worker(logs: Queue): def logger_worker(logs: Queue):

View File

@ -24,6 +24,10 @@ def worker_main(context: WorkerContext, server: ServerContext):
while True: while True:
try: try:
if not context.is_current():
logger.warning("worker has been replaced, exiting")
exit(3)
name, fn, args, kwargs = context.pending.get(timeout=1.0) name, fn, args, kwargs = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, name) logger.info("worker for %s got job: %s", context.device.device, name)