# onnx-web/api/onnx_web/worker/pool.py

from collections import Counter
from logging import getLogger
from queue import Empty
from threading import Thread
from typing import Callable, Dict, List, Optional, Tuple

from torch.multiprocessing import Process, Queue, Value

from ..params import DeviceParams
from ..server import ServerContext
from .context import WorkerContext
from .worker import worker_main

logger = getLogger(__name__)


class DevicePoolExecutor:
    context: Dict[str, WorkerContext] = None  # Device -> Context
    devices: List[DeviceParams] = None
    pending: Dict[str, "Queue[WorkerContext]"] = None
    workers: Dict[str, Process] = None
    active_jobs: Dict[str, Tuple[str, int]] = None  # should be Dict[Device, JobStatus]
    finished_jobs: List[Tuple[str, int, bool]] = None  # should be List[JobStatus]
    def __init__(
        self,
        server: ServerContext,
        devices: List[DeviceParams],
        max_jobs_per_worker: int = 10,
        join_timeout: float = 1.0,
    ):
        self.server = server
        self.devices = devices
        self.max_jobs_per_worker = max_jobs_per_worker
        self.join_timeout = join_timeout

        self.context = {}
        self.pending = {}
        self.threads = {}
        self.workers = {}

        self.active_jobs = {}
        self.cancelled_jobs = []
        self.finished_jobs = []
        self.total_jobs = 0  # TODO: turn this into a Dict per-worker

        self.logs = Queue()
        self.progress = Queue()
        self.finished = Queue()

        self.create_logger_worker()
        self.create_progress_worker()
        self.create_finished_worker()

        for device in devices:
            self.create_device_worker(device)

        logger.debug("testing log worker")
        self.logs.put("testing")

    def create_device_worker(self, device: DeviceParams) -> None:
        name = device.device

        # reuse the queue if possible, to keep queued jobs
        if name in self.pending:
            pending = self.pending[name]
        else:
            pending = Queue()
            self.pending[name] = pending

        context = WorkerContext(
            name,
            device,
            cancel=Value("B", False),
            progress=self.progress,
            finished=self.finished,
            logs=self.logs,
            pending=pending,
        )
        self.context[name] = context
        self.workers[name] = Process(target=worker_main, args=(context, self.server))

        logger.debug("starting worker for device %s", device)
        self.workers[name].start()

    def create_logger_worker(self) -> None:
        def logger_worker(logs: Queue):
            logger.info("checking in from logger worker thread")

            while True:
                try:
                    job = logs.get(timeout=(self.join_timeout / 2))
                    # append rather than truncate, so earlier messages survive
                    with open("worker.log", "a") as f:
                        logger.info("got log: %s", job)
                        f.write(str(job) + "\n\n")
                except Empty:
                    pass
                except Exception as err:
                    logger.error("error in log worker: %s", err)

        logger_thread = Thread(target=logger_worker, args=(self.logs,))
        self.threads["logger"] = logger_thread

        logger.debug("starting logger worker")
        logger_thread.start()

    def create_progress_worker(self) -> None:
        def progress_worker(progress: Queue):
            logger.info("checking in from progress worker thread")

            while True:
                try:
                    job, device, value = progress.get(timeout=(self.join_timeout / 2))
                    logger.info("progress update for job: %s to %s", job, value)
                    self.active_jobs[job] = (device, value)
                    if job in self.cancelled_jobs:
                        logger.debug(
                            "setting flag for cancelled job: %s on %s", job, device
                        )
                        self.context[device].set_cancel()
                except Empty:
                    pass
                except Exception as err:
                    logger.error("error in progress worker: %s", err)

        progress_thread = Thread(target=progress_worker, args=(self.progress,))
        self.threads["progress"] = progress_thread

        logger.debug("starting progress worker")
        progress_thread.start()

    def create_finished_worker(self) -> None:
        def finished_worker(finished: Queue):
            logger.info("checking in from finished worker thread")

            while True:
                try:
                    job, device = finished.get(timeout=(self.join_timeout / 2))
                    logger.info("job has been finished: %s", job)
                    context = self.context[device]
                    _device, progress = self.active_jobs[job]
                    self.finished_jobs.append((job, progress, context.cancel.value))
                    del self.active_jobs[job]
                except Empty:
                    pass
                except Exception as err:
                    logger.error("error in finished worker: %s", err)

        finished_thread = Thread(target=finished_worker, args=(self.finished,))
        self.threads["finished"] = finished_thread

        logger.debug("starting finished worker")
        finished_thread.start()

    def get_job_context(self, key: str) -> WorkerContext:
        device, _progress = self.active_jobs[key]
        return self.context[device]

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        # respect overrides if possible
        if needs_device is not None:
            for i in range(len(self.devices)):
                if self.devices[i].device == needs_device.device:
                    return i

        # count queued jobs per device index, then pick the least-busy device
        jobs = Counter(
            {i: self.pending[d.device].qsize() for i, d in enumerate(self.devices)}
        )
        queued = jobs.most_common()
        logger.debug("jobs queued by device: %s", queued)

        lowest_count = queued[-1][1]
        lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
        lowest_devices.sort()

        return lowest_devices[0]

    def cancel(self, key: str) -> bool:
        """
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.
        """
        self.cancelled_jobs.append(key)

        if key not in self.active_jobs:
            logger.debug("cancelled job has not been started yet: %s", key)
            return True

        device, _progress = self.active_jobs[key]
        logger.info("cancelling job %s, active on device %s", key, device)

        context = self.context[device]
        context.set_cancel()

        return True

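    # A minimal sketch of how a job function might honor this cooperative
    # cancellation, assuming the worker passes its WorkerContext as the first
    # argument; the job body and step loop here are hypothetical:
    #
    #     def cancellable_job(context: WorkerContext, steps: int) -> None:
    #         for step in range(steps):
    #             if context.cancel.value:
    #                 logger.info("stopping cancelled job at step %s", step)
    #                 return
    #             ...  # one unit of work, reporting progress as it goes
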
    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Check if a job has been finished and report the last progress update.
        If the job is still active or pending, the first item will be False.
        If the job is not finished or active, the first item will be None.
        """
        for k, p, c in self.finished_jobs:
            if k == key:
                return (True, p)

        if key not in self.active_jobs:
            logger.warning("checking status for unknown job: %s", key)
            return (None, 0)

        _device, progress = self.active_jobs[key]
        return (False, progress)

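    # For example (illustrative values only): done() returns (True, 100) once
    # the job appears in finished_jobs, (False, 40) while it is still reporting
    # progress, and (None, 0) for a key the pool has never seen.
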
    def join(self):
        logger.debug("stopping worker pool")

        for device, worker in self.workers.items():
            if worker.is_alive():
                logger.debug("stopping worker for device %s", device)
                worker.join(self.join_timeout)
            else:
                logger.debug("worker for device %s has died", device)

        for name, thread in self.threads.items():
            logger.debug("stopping worker thread: %s", name)
            thread.join(self.join_timeout)

        logger.debug("closing queues")
        self.logs.close()
        self.finished.close()
        self.progress.close()

        # close every pending queue, then drop them all at once; deleting keys
        # while iterating over the dict would raise a RuntimeError
        for queue in self.pending.values():
            queue.close()

        self.pending.clear()

        logger.debug("worker pool fully joined")

    def recycle(self):
        for name, proc in self.workers.items():
            if proc.is_alive():
                logger.debug("shutting down worker for device %s", name)
                proc.join(self.join_timeout)
                proc.terminate()
            else:
                logger.warning("worker for device %s has died", name)

            self.workers[name] = None
            del proc

        logger.info("starting new workers")

        for device in self.devices:
            self.create_device_worker(device)

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional[DeviceParams] = None,
        **kwargs,
    ) -> None:
        self.total_jobs += 1
        logger.debug("pool job count: %s", self.total_jobs)
        if self.total_jobs > self.max_jobs_per_worker:
            self.recycle()
            self.total_jobs = 0

        device_idx = self.get_next_device(needs_device=needs_device)
        logger.info(
            "assigning job %s to device %s: %s",
            key,
            device_idx,
            self.devices[device_idx],
        )

        device = self.devices[device_idx].device
        self.pending[device].put((key, fn, args, kwargs), block=False)

    def status(self) -> List[Tuple[str, int, bool, bool]]:
        # active_jobs maps job name -> (device, progress), so unpack the tuple
        history = [
            (name, progress, False, name in self.cancelled_jobs)
            for name, (_device, progress) in self.active_jobs.items()
        ]
        history.extend(
            [
                (
                    name,
                    progress,
                    True,
                    cancel,
                )
                for name, progress, cancel in self.finished_jobs
            ]
        )
        return history
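

# A usage sketch, not part of the original module: how a caller might construct
# the pool and submit a job. The ServerContext() and DeviceParams(...) calls
# below assume default/simple constructors, and fake_job assumes that workers
# invoke the submitted callable as fn(context, *args, **kwargs); both are
# assumptions for illustration only.
if __name__ == "__main__":
    from time import sleep

    def fake_job(context: WorkerContext, size: int) -> None:
        # stand-in for a real pipeline call
        logger.info("running fake job with size %s", size)

    server = ServerContext()  # assumed: default-constructible for this sketch
    devices = [DeviceParams("cpu", "CPUExecutionProvider")]  # assumed signature
    pool = DevicePoolExecutor(server, devices)

    pool.submit("example-job", fake_job, 512)
    sleep(1.0)
    logger.info("job status: %s", pool.done("example-job"))

    pool.join()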