1
0
Fork 0
onnx-web/api/onnx_web/worker/pool.py

251 lines
8.0 KiB
Python
Raw Normal View History

2023-02-26 05:49:39 +00:00
from collections import Counter
from logging import getLogger
2023-02-27 02:09:42 +00:00
from threading import Thread
2023-02-26 05:49:39 +00:00
from typing import Callable, Dict, List, Optional, Tuple
2023-02-27 02:09:42 +00:00
from torch.multiprocessing import Process, Queue, Value
2023-02-26 20:15:30 +00:00
2023-02-26 05:49:39 +00:00
from ..params import DeviceParams
2023-02-26 18:32:48 +00:00
from ..server import ServerContext
2023-02-26 05:49:39 +00:00
from .context import WorkerContext
from .worker import logger_init, worker_init
# module-level logger, named after this module
logger = getLogger(name=__name__)
class DevicePoolExecutor:
    """
    Dispatch jobs to one worker process per device and track their progress.

    On construction this starts a logging process (``logger_init``), two
    background threads that read the shared ``progress`` and ``finished``
    queues, and one ``worker_init`` process per entry in ``devices``. Each
    device has its own pending queue; :meth:`submit` puts ``(fn, args, kwargs)``
    tuples on the chosen device's queue.

    Job bookkeeping:

    - ``active_jobs``: job key -> ``(device name, last reported progress)``
    - ``finished_jobs``: ``(job key, last progress, device cancel flag)``

    NOTE(review): both structures are mutated by the queue threads while the
    public methods read them; the methods use single dict operations or
    snapshots instead of explicit locks.
    """

    context: Dict[str, WorkerContext] = None  # device name -> worker context
    devices: List[DeviceParams] = None
    # device name -> queue of (fn, args, kwargs) tuples waiting for that device
    pending: Dict[str, "Queue[Tuple[Callable[..., None], tuple, dict]]"] = None
    workers: Dict[str, Process] = None  # device name -> worker process
    # job key -> (device name, last reported progress)
    active_jobs: Dict[str, Tuple[str, int]] = None
    # (job key, last reported progress, device cancel flag when it finished)
    finished_jobs: List[Tuple[str, int, bool]] = None

    def __init__(
        self,
        server: ServerContext,
        devices: List[DeviceParams],
        max_jobs_per_worker: int = 10,
        join_timeout: float = 5.0,
    ):
        """
        Start the logging process, the queue threads, and one worker per device.

        :param server: server context, passed to every ``worker_init`` process
        :param devices: devices to start workers for; job placement returns
            indices into this list
        :param max_jobs_per_worker: once more than this many jobs have been
            submitted since the last recycle, :meth:`submit` recycles all
            workers before queueing the next job
        :param join_timeout: seconds to wait for each thread/process in
            :meth:`join` and :meth:`recycle`
        """
        self.server = server
        self.devices = devices
        self.max_jobs_per_worker = max_jobs_per_worker
        self.join_timeout = join_timeout

        self.context = {}
        self.pending = {}
        self.workers = {}
        self.active_jobs = {}
        self.finished_jobs = []
        self.total_jobs = 0  # TODO: turn this into a Dict per-worker

        self.progress = Queue()
        self.finished = Queue()

        self.create_logger_worker()
        self.create_queue_workers()
        for device in devices:
            self.create_device_worker(device)

        logger.debug("testing log worker")
        self.log_queue.put("testing")

    def create_logger_worker(self) -> None:
        """
        Create ``log_queue`` and start a ``logger_init`` process that reads it.
        """
        self.log_queue = Queue()
        self.logger = Process(target=logger_init, args=(self.log_queue,))

        logger.debug("starting log worker")
        self.logger.start()

    def create_device_worker(self, device: DeviceParams) -> None:
        """
        Start a ``worker_init`` process for ``device`` with a fresh context.

        An existing pending queue for the device is reused, so jobs queued
        before a recycle are not lost. A new cancel flag is created each time.
        """
        name = device.device

        # reuse the queue if possible, to keep queued jobs
        if name in self.pending:
            pending = self.pending[name]
        else:
            pending = Queue()
            self.pending[name] = pending

        context = WorkerContext(
            name,
            device,
            cancel=Value("B", False),
            progress=self.progress,
            finished=self.finished,
            logs=self.log_queue,
            pending=pending,
        )
        self.context[name] = context
        self.workers[name] = Process(target=worker_init, args=(context, self.server))

        logger.debug("starting worker for device %s", device)
        self.workers[name].start()

    def create_queue_workers(self) -> None:
        """
        Start the threads that consume the shared progress and finished queues.

        Messages are presumably sent by the device workers (which receive these
        queues through their ``WorkerContext``):

        - progress: ``(job key, device name, progress value)``
        - finished: ``(job key, device name)``
        """

        def progress_worker(progress: Queue):
            logger.info("checking in from progress worker thread")
            while True:
                job, device, value = progress.get()
                logger.info("progress update for job: %s, %s", job, value)
                self.active_jobs[job] = (device, value)

        def finished_worker(finished: Queue):
            logger.info("checking in from finished worker thread")
            while True:
                job, device = finished.get()
                logger.info("job has been finished: %s", job)
                context = self.context[device]
                # a job that never reported progress (or whose entry is
                # missing) must not raise here: that would kill this thread
                # and stop all future completion tracking
                _device, progress = self.active_jobs.get(job, (device, 0))
                # record the result before removing the active entry, so the
                # job is always visible in one of the two structures
                self.finished_jobs.append((job, progress, context.cancel.value))
                self.active_jobs.pop(job, None)

        # both loops never return, so they are daemon threads: otherwise join()
        # could only time out on them and they would block interpreter exit
        self.progress_thread = Thread(
            target=progress_worker, args=(self.progress,), daemon=True
        )
        self.progress_thread.start()

        self.finished_thread = Thread(
            target=finished_worker, args=(self.finished,), daemon=True
        )
        self.finished_thread.start()

    def get_job_context(self, key: str) -> WorkerContext:
        """
        Return the worker context of the device an active job is assigned to.

        :raises KeyError: if ``key`` is not an active job
        """
        device, _progress = self.active_jobs[key]
        return self.context[device]

    def cancel(self, key: str) -> bool:
        """
        Request cancellation of an active job.

        This sets the cancel flag in the worker context of the device the job
        was assigned to; the worker is expected to observe the flag. The flag
        is per-device, not per-job.

        NOTE(review): cancelling a job that is still queued behind another job
        on the same device flags whichever job that worker is running.

        :return: ``False`` if ``key`` is not an active job, else ``True``
        """
        active = self.active_jobs.get(key)
        if active is None:
            logger.warning("attempting to cancel unknown job: %s", key)
            return False

        device, _progress = active
        context = self.context[device]
        logger.info("cancelling job %s on device %s", key, device)

        # get_lock() returns the lock object; it must be acquired, not tested
        with context.cancel.get_lock():
            context.cancel.value = True

        return True

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Report whether a job has finished, along with its last progress value.

        :return: ``(True, progress)`` for a finished job, ``(False, progress)``
            for an active job, and ``(None, 0)`` for an unknown key
        """
        for job_key, progress, _cancelled in self.finished_jobs:
            if job_key == key:
                return (True, progress)

        # single lookup: the finished thread may remove the entry concurrently
        active = self.active_jobs.get(key)
        if active is None:
            logger.warning("checking status for unknown job: %s", key)
            return (None, 0)

        # TODO: prune here, maybe?
        _device, progress = active
        return (False, progress)

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        """
        Pick the index into ``self.devices`` for the next job.

        If ``needs_device`` matches a configured device by ``.device`` name, that
        device's index is returned. Otherwise the device with the fewest jobs in
        its pending queue is chosen, ties going to the lowest index.

        NOTE(review): ``multiprocessing.Queue.qsize()`` is not implemented on
        some platforms (e.g. macOS) and raises ``NotImplementedError`` there.

        :raises IndexError: if no devices are configured
        """
        # respect overrides if possible
        if needs_device is not None:
            for i, device in enumerate(self.devices):
                if device.device == needs_device.device:
                    return i

        # device index -> number of jobs waiting in that device's queue
        jobs = Counter(
            {i: self.pending[d.device].qsize() for i, d in enumerate(self.devices)}
        )

        queued = jobs.most_common()
        logger.debug("jobs queued by device: %s", queued)

        lowest_count = queued[-1][1]
        lowest_devices = sorted(idx for idx, count in queued if count == lowest_count)
        return lowest_devices[0]

    def join(self):
        """
        Wait up to ``join_timeout`` seconds for each thread and process.

        Nothing is terminated here. The queue threads loop forever, so their
        joins always time out; whether the device workers exit depends on
        ``worker_init``.
        """
        self.progress_thread.join(self.join_timeout)
        self.finished_thread.join(self.join_timeout)

        for device, worker in self.workers.items():
            if worker.is_alive():
                logger.info("stopping worker for device %s", device)
                worker.join(self.join_timeout)

        if self.logger.is_alive():
            self.logger.join(self.join_timeout)

    def recycle(self):
        """
        Stop every device worker and start a fresh one for each device.

        Live workers are joined with ``join_timeout`` and then terminated.
        Pending queues are kept, so queued jobs survive the restart.
        """
        for name, proc in list(self.workers.items()):
            if proc.is_alive():
                logger.debug("shutting down worker for device %s", name)
                proc.join(self.join_timeout)
                proc.terminate()
            else:
                logger.warning("worker for device %s has died", name)

            self.workers[name] = None

        logger.info("starting new workers")

        for device in self.devices:
            self.create_device_worker(device)

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional[DeviceParams] = None,
        **kwargs,
    ) -> None:
        """
        Queue ``fn`` with ``args``/``kwargs`` as job ``key`` on a device worker.

        :param key: job identifier used by :meth:`cancel`, :meth:`done`, and
            :meth:`status`
        :param fn: callable placed on the chosen device's pending queue
        :param needs_device: if given and configured, pin the job to that device
        """
        self.total_jobs += 1
        logger.debug("pool job count: %s", self.total_jobs)
        if self.total_jobs > self.max_jobs_per_worker:
            self.recycle()
            self.total_jobs = 0

        device_idx = self.get_next_device(needs_device=needs_device)
        logger.info(
            "assigning job %s to device %s: %s",
            key,
            device_idx,
            self.devices[device_idx],
        )

        device = self.devices[device_idx]
        # register the job before queueing it, so a fast worker's progress or
        # finished message always finds this entry instead of being overwritten
        # by it; use the same (device, progress) shape the progress thread writes
        self.active_jobs[key] = (device.device, 0)
        self.pending[device.device].put((fn, args, kwargs))

    def status(self) -> List[Tuple[str, bool, int, bool, bool, int]]:
        """
        Summarize every known job.

        Each entry is ``(job key, worker alive, pending queue size, cancel flag,
        finished, progress)``. Finished jobs report ``False`` and ``0`` for the
        worker-alive and queue-size fields.
        """
        # snapshot: the queue threads mutate these while we read them
        active = self.active_jobs.copy()
        finished = list(self.finished_jobs)

        jobs = [
            (
                name,
                self.workers[device].is_alive(),
                self.pending[device].qsize(),
                self.context[device].cancel.value,
                False,
                progress,
            )
            for name, (device, progress) in active.items()
        ]
        jobs.extend(
            (name, False, 0, cancel, True, progress)
            for name, progress, cancel in finished
        )
        return jobs