1
0
Fork 0

fix(api): track currently active worker for each device

This commit is contained in:
Sean Sube 2023-03-05 21:28:21 -06:00
parent 57fed94337
commit c0a01efef4
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 29 additions and 2 deletions

View File

@ -1,4 +1,5 @@
from logging import getLogger
from os import getpid
from typing import Any, Callable, Tuple
from torch.multiprocessing import Queue, Value
@ -15,6 +16,7 @@ class WorkerContext:
cancel: "Value[bool]"
job: str
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]"
current: "Value[int]"
progress: "Queue[Tuple[str, str, int]]"
def __init__(
@ -26,6 +28,7 @@ class WorkerContext:
pending: "Queue[Tuple[str, Callable[..., None], Any, Any]]",
progress: "Queue[Tuple[str, str, int]]",
finished: "Queue[Tuple[str, str]]",
current: "Value[int]",
):
self.job = job
self.device = device
@ -34,10 +37,17 @@ class WorkerContext:
self.finished = finished
self.logs = logs
self.pending = pending
self.current = current
def is_cancelled(self) -> bool:
    """
    Check whether cancellation has been requested for this worker.

    Returns the present value of the `cancel` flag held by this context.
    """
    cancel_flag = self.cancel
    return cancel_flag.value
def is_current(self) -> bool:
    """
    Check whether this process is the worker recorded in `current`.

    When the `current` value is zero or less, no worker PID has been
    recorded yet and the calling process is treated as current. Otherwise
    the recorded value must equal the PID of the calling process.
    """
    # guard clause: nothing recorded yet, so there is nothing to conflict with
    if self.current.value <= 0:
        return True

    return getpid() == self.current.value
def get_device(self) -> DeviceParams:
"""
Get the device assigned to this job.

View File

@ -23,6 +23,7 @@ class DevicePoolExecutor:
leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"]
pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
threads: Dict[str, Thread]
workers: Dict[str, Process]
@ -52,6 +53,7 @@ class DevicePoolExecutor:
self.leaking = []
self.context = {}
self.current = {}
self.pending = {}
self.threads = {}
self.workers = {}
@ -85,6 +87,14 @@ class DevicePoolExecutor:
pending = Queue(self.max_pending_per_worker)
self.pending[name] = pending
if name in self.current:
logger.debug("using existing current worker value")
current = self.current[name]
else:
logger.debug("creating new current worker value")
current = Value("L", 0)
self.current[name] = current
context = WorkerContext(
name,
device,
@ -93,16 +103,19 @@ class DevicePoolExecutor:
finished=self.finished,
logs=self.logs,
pending=pending,
current=current,
)
self.context[name] = context
self.workers[name] = Process(
worker = Process(
name=f"onnx-web worker: {name}",
target=worker_main,
args=(context, self.server),
)
logger.debug("starting worker for device %s", device)
self.workers[name].start()
worker.start()
self.workers[name] = worker
current.value = worker.pid
def create_logger_worker(self) -> None:
def logger_worker(logs: Queue):

View File

@ -24,6 +24,10 @@ def worker_main(context: WorkerContext, server: ServerContext):
while True:
try:
if not context.is_current():
logger.warning("worker has been replaced, exiting")
exit(3)
name, fn, args, kwargs = context.pending.get(timeout=1.0)
logger.info("worker for %s got job: %s", context.device.device, name)