# onnx-web/api/onnx_web/worker/pool.py
from collections import Counter
from logging import getLogger
from multiprocessing import Queue
from torch.multiprocessing import Lock, Process, Value
from typing import Callable, Dict, List, Optional, Tuple
from ..params import DeviceParams
from .context import WorkerContext
from .worker import logger_init, worker_init
logger = getLogger(__name__)


class DevicePoolExecutor:
    """
    Executor that runs jobs on a pool of worker processes, one per device,
    plus a dedicated logger process. Jobs are distributed by pushing them
    onto the pending queue of the least-loaded device.
    """

    # populated by __init__; declared here for documentation purposes
    devices: List[DeviceParams] = None
    finished: List[Tuple[str, int]] = None
    pending: Dict[str, "Queue[WorkerContext]"] = None
    progress: Dict[str, Value] = None
    workers: Dict[str, Process] = None

    def __init__(
        self,
        devices: List[DeviceParams],
        finished_limit: int = 10,
    ):
        """
        Start the logger worker and one job worker per device.

        :param devices: devices to run jobs on
        :param finished_limit: maximum number of finished jobs to retain
        """
        self.devices = devices
        self.finished = []
        self.finished_limit = finished_limit
        self.lock = Lock()
        self.pending = {}
        self.progress = {}
        self.workers = {}

        log_queue = Queue()
        logger_context = WorkerContext("logger", None, None, log_queue, None)

        logger.debug("starting log worker")
        self.logger = Process(target=logger_init, args=(self.lock, logger_context))
        self.logger.start()

        # create a pending queue and progress value for each device
        for device in devices:
            name = device.device
            cancel = Value("B", False, lock=self.lock)
            progress = Value("I", 0, lock=self.lock)
            pending = Queue()
            context = WorkerContext(name, cancel, device, pending, progress)
            self.pending[name] = pending
            # fix: previously stored the pending queue here, clobbering the
            # progress Value so progress lookups returned a Queue instead
            self.progress[name] = progress
            logger.debug("starting worker for device %s", device)
            self.workers[name] = Process(target=worker_init, args=(self.lock, context))
            self.workers[name].start()

        logger.debug("testing log worker")
        log_queue.put("testing")

    def cancel(self, key: str) -> bool:
        """
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.
        """
        raise NotImplementedError()

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Check whether a job has finished and report its final progress.

        :return: (True, progress) when the key is finished, (None, 0) when
            the key is unknown
        """
        for k, progress in self.finished:
            if key == k:
                return (True, progress)

        # fix: logger.warn is deprecated in favor of logger.warning
        logger.warning("checking status for unknown key: %s", key)
        return (None, 0)

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        """
        Pick the index of the device the next job should run on: the
        requested device when one is given, otherwise the device with the
        fewest pending jobs (lowest index wins ties).
        """
        # respect overrides if possible
        if needs_device is not None:
            for i, device in enumerate(self.devices):
                if device.device == needs_device.device:
                    return i

        pending = [self.pending[d.device].qsize() for d in self.devices]
        logger.debug("jobs queued by device: %s", pending)

        # fix: the previous Counter-based selection fed the list of queue
        # sizes to Counter.update, which counted the sizes as if they were
        # device indexes — picking busier devices and even out-of-range
        # indexes. Take the least-loaded index directly instead.
        return min(range(len(pending)), key=pending.__getitem__)

    def join(self):
        """Wait (up to 5s each) for every worker process still alive."""
        for device, worker in self.workers.items():
            if worker.is_alive():
                logger.info("stopping worker for device %s", device)
                worker.join(5)

    def prune(self):
        """Trim the finished-job list down to the configured limit, keeping
        the most recent entries."""
        finished_count = len(self.finished)
        if finished_count > self.finished_limit:
            logger.debug(
                "pruning %s of %s finished jobs",
                finished_count - self.finished_limit,
                finished_count,
            )
            self.finished[:] = self.finished[-self.finished_limit:]

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional[DeviceParams] = None,
        **kwargs,
    ) -> None:
        """
        Queue a job to run on the next available device.

        :param key: identifier used to check status later
        :param fn: job entry point, executed in the worker process
        :param needs_device: pin the job to a specific device, if set
        """
        self.prune()
        device_idx = self.get_next_device(needs_device=needs_device)
        logger.info(
            "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
        )

        device = self.devices[device_idx]
        queue = self.pending[device.device]
        queue.put((fn, args, kwargs))

    def status(self) -> List[Tuple[str, int, bool, int]]:
        """
        Report (device, queue depth, worker alive) for each device, followed
        by the (key, progress) pairs of finished jobs.
        """
        pending = [
            (
                device.device,
                self.pending[device.device].qsize(),
                self.workers[device.device].is_alive(),
            )
            for device in self.devices
        ]
        pending.extend(self.finished)
        return pending