1
0
Fork 0
onnx-web/api/onnx_web/worker/pool.py

149 lines
4.7 KiB
Python
Raw Normal View History

2023-02-26 05:49:39 +00:00
from collections import Counter
from logging import getLogger
from multiprocessing import Queue
from torch.multiprocessing import Lock, Process, Value
from typing import Callable, Dict, List, Optional, Tuple
from ..params import DeviceParams
from .context import WorkerContext
from .worker import logger_init, worker_init
# Module-level logger, named after this module.
logger = getLogger(__name__)


class DevicePoolExecutor:
    """
    Distribute jobs across one worker process per device.

    For each device, the pool creates a pending queue, a progress counter, a
    cancel flag and a lock. These are bundled into a WorkerContext and handed
    to a process running ``worker_init``. A separate process running
    ``logger_init`` is started with ``log_queue``.
    """

    devices: List["DeviceParams"] = None
    # (key, progress) pairs for finished jobs; trimmed by prune()
    finished: List[Tuple[str, int]] = None
    # per-device queues; submit() puts (fn, args, kwargs) tuples on them
    pending: Dict[str, "Queue[Tuple[Callable[..., None], tuple, dict]]"] = None
    progress: Dict[str, "Value"] = None
    workers: Dict[str, "Process"] = None
    context: Dict[str, "WorkerContext"] = None
    locks: Dict[str, "Lock"] = None

    def __init__(
        self,
        devices: List["DeviceParams"],
        finished_limit: int = 10,
    ):
        """
        Start the log worker and one worker process per device.

        Args:
            devices: devices to run jobs on. Each is keyed by ``device.device``
                and gets its own worker process.
            finished_limit: maximum number of entries ``prune`` keeps in
                ``finished``.
        """
        self.devices = devices
        self.finished = []
        self.finished_limit = finished_limit
        self.context = {}
        self.locks = {}
        self.pending = {}
        self.progress = {}
        self.workers = {}

        logger.debug("starting log worker")
        self.log_queue = Queue()
        # `self.lock` is never assigned on this class (referencing it raised
        # AttributeError), so the log worker gets a lock of its own.
        self.log_lock = Lock()
        self.logger = Process(
            target=logger_init, args=(self.log_lock, self.log_queue)
        )
        self.logger.start()

        logger.debug("testing log worker")
        self.log_queue.put("testing")

        # create a pending queue and progress value for each device
        for device in devices:
            name = device.device

            # the cancel flag and progress counter are guarded by the same
            # lock that is passed to the device's worker process
            lock = Lock()
            self.locks[name] = lock
            cancel = Value("B", False, lock=lock)
            progress = Value("I", 0, lock=lock)
            self.progress[name] = progress

            pending = Queue()
            self.pending[name] = pending

            context = WorkerContext(name, cancel, device, pending, progress)
            self.context[name] = context

            logger.debug("starting worker for device %s", device)
            self.workers[name] = Process(target=worker_init, args=(lock, context))
            self.workers[name].start()

    def cancel(self, key: str) -> bool:
        """
        Cancel a job by key (not implemented yet; always raises).

        Intended behavior: if the job has not been started, it should never
        run. If it has been started, it should be cancelled on the next
        progress callback. Queued items are ``(fn, args, kwargs)`` without
        the job key, so workers currently cannot tell which job to skip.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Look up a job in the finished list.

        Returns:
            ``(True, progress)`` if ``key`` is in ``finished``. Otherwise logs
            a warning and returns ``(None, 0)``. Submitted keys are not
            recorded anywhere else in this class, so jobs that are still
            queued or running also return ``(None, 0)``.
        """
        for finished_key, progress in self.finished:
            if finished_key == key:
                return (True, progress)

        logger.warning("checking status for unknown key: %s", key)
        return (None, 0)

    def get_next_device(self, needs_device: Optional["DeviceParams"] = None) -> int:
        """
        Choose the index into ``devices`` for the next job.

        If ``needs_device`` matches one of the pool's devices (by
        ``.device``), that index is returned. Otherwise the device with the
        fewest queued jobs is chosen. Ties go to the lowest index.

        Raises:
            ValueError: if the pool has no devices.
        """
        # respect overrides if possible
        if needs_device is not None:
            for i, device in enumerate(self.devices):
                if device.device == needs_device.device:
                    return i

            logger.warning(
                "requested device %s is not in the pool, using the least busy device",
                needs_device.device,
            )

        if not self.devices:
            raise ValueError("no devices available to run jobs")

        # Compare per-device queue lengths directly. Passing them to
        # Counter.update would count how often each queue *size* occurs,
        # not how many jobs each device has.
        queued = [self.pending[d.device].qsize() for d in self.devices]
        logger.debug("jobs queued by device: %s", queued)

        # min() returns the first index holding the smallest count
        return min(range(len(queued)), key=queued.__getitem__)

    def join(self):
        """
        Stop the device workers and the log worker.

        Each live process gets up to 5 seconds to exit on its own. Nothing in
        this class signals the workers to finish, so any process still alive
        after the timeout is terminated instead of being left running.
        """
        for device, worker in self.workers.items():
            if worker.is_alive():
                logger.info("stopping worker for device %s", device)
                self._stop_process(worker)

        if self.logger.is_alive():
            self._stop_process(self.logger)

    @staticmethod
    def _stop_process(process: "Process", timeout: float = 5) -> None:
        """Join ``process``, terminating it if it outlives ``timeout`` seconds."""
        process.join(timeout)
        if process.is_alive():
            logger.warning(
                "process %s did not exit within %ss, terminating it",
                process.name,
                timeout,
            )
            process.terminate()
            process.join(timeout)

    def prune(self):
        """
        Trim ``finished`` to at most ``finished_limit`` entries.

        Entries are dropped from the front of the list, presumably the oldest
        jobs. A negative limit is treated as 0.
        """
        limit = max(self.finished_limit, 0)
        finished_count = len(self.finished)
        if finished_count > limit:
            excess = finished_count - limit
            logger.debug(
                "pruning %s of %s finished jobs",
                excess,
                finished_count,
            )
            # delete the head rather than slicing the tail: finished[-0:]
            # is the whole list, so a limit of 0 never pruned anything
            del self.finished[:excess]

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional["DeviceParams"] = None,
        **kwargs,
    ) -> None:
        """
        Queue ``fn(*args, **kwargs)`` on the requested or least busy device.

        Args:
            key: job identifier. Currently used only for logging; it is not
                sent to the worker with the job.
            fn: callable the worker should run.
            needs_device: optional device override (see ``get_next_device``).
        """
        self.prune()
        device_idx = self.get_next_device(needs_device=needs_device)
        device = self.devices[device_idx]
        logger.info("assigning job %s to device %s: %s", key, device_idx, device)
        self.pending[device.device].put((fn, args, kwargs))

    def status(self) -> List[tuple]:
        """
        Summarize the pool.

        Returns:
            One ``(device, queued_jobs, progress, worker_alive)`` tuple per
            device, followed by the ``(key, progress)`` entries of
            ``finished``.
        """
        summary = [
            (
                device.device,
                self.pending[device.device].qsize(),
                self.progress[device.device].value,
                self.workers[device.device].is_alive(),
            )
            for device in self.devices
        ]
        summary.extend(self.finished)
        return summary