2023-02-26 05:49:39 +00:00
|
|
|
from collections import Counter
|
|
|
|
from logging import getLogger
|
2023-02-28 04:45:29 +00:00
|
|
|
from queue import Empty
|
2023-02-27 02:09:42 +00:00
|
|
|
from threading import Thread
|
2023-03-01 03:44:52 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
2023-02-26 05:49:39 +00:00
|
|
|
|
2023-02-27 02:09:42 +00:00
|
|
|
from torch.multiprocessing import Process, Queue, Value
|
2023-02-26 20:15:30 +00:00
|
|
|
|
2023-02-26 05:49:39 +00:00
|
|
|
from ..params import DeviceParams
|
2023-02-26 18:32:48 +00:00
|
|
|
from ..server import ServerContext
|
2023-02-26 05:49:39 +00:00
|
|
|
from .context import WorkerContext
|
2023-02-27 23:14:53 +00:00
|
|
|
from .worker import worker_main
|
2023-02-26 05:49:39 +00:00
|
|
|
|
|
|
|
# module-level logger, named after this module
logger = getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class DevicePoolExecutor:
    """
    Run jobs on a pool of worker processes, one process per device.

    Jobs are handed to each worker through a per-device pending queue. Every
    worker is also given three shared queues (logs, progress, and finished),
    each of which is drained by a daemon thread in this process.
    """

    server: "ServerContext"
    devices: List["DeviceParams"]
    max_jobs_per_worker: int
    max_pending_per_worker: int
    join_timeout: float

    context: Dict[str, "WorkerContext"]  # Device -> Context
    pending: Dict[str, "Queue[Tuple[str, Callable[..., None], Any, Any]]"]
    threads: Dict[str, Thread]
    workers: Dict[str, Process]

    # job key -> (device, last progress); should be Dict[Device, JobStatus]
    active_jobs: Dict[str, Tuple[str, int]]
    cancelled_jobs: List[str]
    # (job key, last progress, cancelled); should be List[JobStatus]
    finished_jobs: List[Tuple[str, int, bool]]
    total_jobs: int

    logs: "Queue"
    progress: "Queue[Tuple[str, str, int]]"
    finished: "Queue[Tuple[str, str]]"

    def __init__(
        self,
        server: "ServerContext",
        devices: List["DeviceParams"],
        max_jobs_per_worker: int = 10,
        max_pending_per_worker: int = 100,
        join_timeout: float = 1.0,
    ):
        """
        Create the shared queues, start the queue threads, and start one
        worker process per device.

        server: passed to every worker process.
        devices: devices to start workers for.
        max_jobs_per_worker: once more than this many jobs have been
            submitted (counted across the whole pool), the workers are
            recycled.
        max_pending_per_worker: maximum size of each queue.
        join_timeout: seconds to wait when joining a worker or thread. The
            queue threads poll their queue with half of this timeout.
        """
        self.server = server
        self.devices = devices
        self.max_jobs_per_worker = max_jobs_per_worker
        self.max_pending_per_worker = max_pending_per_worker
        self.join_timeout = join_timeout

        self.context = {}
        self.pending = {}
        self.threads = {}
        self.workers = {}

        self.active_jobs = {}
        self.cancelled_jobs = []
        self.finished_jobs = []
        self.total_jobs = 0  # TODO: turn this into a Dict per-worker

        self.logs = Queue(self.max_pending_per_worker)
        self.progress = Queue(self.max_pending_per_worker)
        self.finished = Queue(self.max_pending_per_worker)

        # the queue threads are started before the workers that feed them
        self.create_logger_worker()
        self.create_progress_worker()
        self.create_finished_worker()

        for device in devices:
            self.create_device_worker(device)

        logger.debug("testing log worker")
        self.logs.put("testing")

    def create_device_worker(self, device: "DeviceParams") -> None:
        """
        Build the context for a device and start a worker process for it.

        An existing pending queue for the device is reused, so jobs that were
        queued before a recycle are kept.
        """
        name = device.device

        # reuse the queue if possible, to keep queued jobs
        if name in self.pending:
            logger.debug("using existing pending job queue")
            pending = self.pending[name]
        else:
            logger.debug("creating new pending job queue")
            pending = Queue(self.max_pending_per_worker)
            self.pending[name] = pending

        context = WorkerContext(
            name,
            device,
            cancel=Value("B", False),  # shared cancel flag, initially unset
            progress=self.progress,
            finished=self.finished,
            logs=self.logs,
            pending=pending,
        )
        self.context[name] = context
        self.workers[name] = Process(
            name=f"onnx-web worker: {name}",
            target=worker_main,
            args=(context, self.server),
        )

        logger.debug("starting worker for device %s", device)
        self.workers[name].start()

    def create_logger_worker(self) -> None:
        """
        Start the daemon thread that drains the logs queue into worker.log.
        """

        def logger_worker(logs: Queue):
            logger.info("checking in from logger worker thread")

            while True:
                # only the queue read may end the loop, so that a ValueError
                # raised while handling a message is not mistaken for a
                # closed queue
                try:
                    job = logs.get(timeout=(self.join_timeout / 2))
                except Empty:
                    continue
                except ValueError:
                    # get() raises ValueError once the queue has been closed
                    break
                except Exception as err:
                    logger.error("error in log worker: %s", err)
                    continue

                try:
                    logger.info("got log: %s", job)
                    # append, so earlier messages are not truncated away
                    with open("worker.log", "a") as f:
                        f.write(str(job) + "\n\n")
                except Exception as err:
                    logger.error("error in log worker: %s", err)

        logger_thread = Thread(
            name="onnx-web logger", target=logger_worker, args=(self.logs,), daemon=True
        )
        self.threads["logger"] = logger_thread

        logger.debug("starting logger worker")
        logger_thread.start()

    def create_progress_worker(self) -> None:
        """
        Start the daemon thread that records progress updates in active_jobs
        and sets the cancel flag for jobs that have been cancelled.
        """

        def progress_worker(progress: Queue):
            logger.info("checking in from progress worker thread")

            while True:
                try:
                    update = progress.get(timeout=(self.join_timeout / 2))
                except Empty:
                    continue
                except ValueError:
                    # get() raises ValueError once the queue has been closed
                    break
                except Exception as err:
                    logger.error("error in progress worker: %s", err)
                    continue

                try:
                    job, device, value = update
                    logger.info("progress update for job: %s to %s", job, value)
                    self.active_jobs[job] = (device, value)
                    if job in self.cancelled_jobs:
                        logger.debug(
                            "setting flag for cancelled job: %s on %s", job, device
                        )
                        self.context[device].set_cancel()
                except Exception as err:
                    logger.error("error in progress worker: %s", err)

        progress_thread = Thread(
            name="onnx-web progress",
            target=progress_worker,
            args=(self.progress,),
            daemon=True,
        )
        self.threads["progress"] = progress_thread

        logger.debug("starting progress worker")
        progress_thread.start()

    def create_finished_worker(self) -> None:
        """
        Start the daemon thread that moves jobs from active_jobs to
        finished_jobs as they are reported on the finished queue.
        """

        def finished_worker(finished: Queue):
            logger.info("checking in from finished worker thread")

            while True:
                try:
                    message = finished.get(timeout=(self.join_timeout / 2))
                except Empty:
                    continue
                except ValueError:
                    # get() raises ValueError once the queue has been closed
                    break
                except Exception as err:
                    logger.error("error in finished worker: %s", err)
                    continue

                try:
                    job, device = message
                    logger.info("job has been finished: %s", job)
                    context = self.context[device]
                    # a job that never reported progress is not in
                    # active_jobs; record it as finished with no progress
                    _device, progress = self.active_jobs.get(job, (device, 0))
                    # append before removing, so the job is never missing
                    # from both collections at once
                    self.finished_jobs.append((job, progress, context.cancel.value))
                    self.active_jobs.pop(job, None)
                except Exception as err:
                    logger.error("error in finished worker: %s", err)

        finished_thread = Thread(
            name="onnx-web finished",
            target=finished_worker,
            args=(self.finished,),
            daemon=True,
        )
        self.threads["finished"] = finished_thread

        logger.debug("started finished worker")
        finished_thread.start()

    def get_job_context(self, key: str) -> "WorkerContext":
        """
        Return the worker context for the device an active job is running on.

        Raises KeyError if the job is not in active_jobs.
        """
        device, _progress = self.active_jobs[key]
        return self.context[device]

    def get_next_device(self, needs_device: Optional["DeviceParams"] = None) -> int:
        """
        Pick the index (into self.devices) of the device for the next job.

        If needs_device names one of the pool's devices, that device is used.
        Otherwise the device with the fewest pending jobs is chosen, with
        ties going to the lowest index.
        """
        # respect overrides if possible
        if needs_device is not None:
            for idx, candidate in enumerate(self.devices):
                if candidate.device == needs_device.device:
                    return idx

        queued = [
            (idx, self.pending[candidate.device].qsize())
            for idx, candidate in enumerate(self.devices)
        ]
        logger.debug("jobs queued by device: %s", queued)

        # min() returns the first minimal item, which is the lowest index
        lowest_idx, _count = min(queued, key=lambda item: item[1])
        return lowest_idx

    def cancel(self, key: str) -> bool:
        """
        Cancel a job. The key is always recorded in cancelled_jobs. If the
        job is active, the cancel flag is set on its device right away;
        otherwise the progress thread sets the flag when the job next
        reports progress.

        Always returns True.
        """
        self.cancelled_jobs.append(key)

        # single lookup, the finished thread may remove the job concurrently
        active = self.active_jobs.get(key)
        if active is None:
            logger.debug("cancelled job has not been started yet: %s", key)
            return True

        device, _progress = active
        logger.info("cancelling job %s, active on device %s", key, device)

        context = self.context[device]
        context.set_cancel()

        return True

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Check if a job has been finished and report the last progress update.

        If the job has finished, the first item will be True.
        If the job is active (has reported progress), the first item will be
        False.
        If the job is neither finished nor active, the first item will be
        None and the progress will be 0. This includes jobs that are still
        pending, since they are only tracked once they report progress.
        """
        for job, progress, _cancelled in self.finished_jobs:
            if job == key:
                return (True, progress)

        # single lookup, the finished thread may remove the job concurrently
        active = self.active_jobs.get(key)
        if active is None:
            logger.warning("checking status for unknown job: %s", key)
            return (None, 0)

        _device, progress = active
        return (False, progress)

    def join(self):
        """
        Close every queue, then wait up to join_timeout for each worker
        process and each queue thread to exit.
        """
        logger.info("stopping worker pool")

        logger.debug("closing queues")
        self.logs.close()
        self.finished.close()
        self.progress.close()
        for queue in self.pending.values():
            queue.close()

        self.pending.clear()

        logger.debug("stopping device workers")
        for device, worker in self.workers.items():
            if worker.is_alive():
                logger.debug("stopping worker for device %s", device)
                # the worker is only waited for, it is not terminated
                worker.join(self.join_timeout)
            else:
                logger.debug("worker for device %s has died", device)

        for name, thread in self.threads.items():
            logger.debug("stopping worker thread: %s", name)
            thread.join(self.join_timeout)

        logger.debug("worker pool fully joined")

    def recycle(self):
        """
        Wait up to join_timeout for each live worker process to exit, drop
        the reference to it, and start a new worker for every device.
        """
        for name, proc in self.workers.items():
            if proc.is_alive():
                logger.debug("shutting down worker for device %s", name)
                # the worker is only waited for, it is not terminated
                proc.join(self.join_timeout)
            else:
                logger.warning("worker for device %s has died", name)

            # replacing the value of an existing key is safe while iterating
            self.workers[name] = None
            del proc

        logger.info("starting new workers")

        for device in self.devices:
            self.create_device_worker(device)

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional["DeviceParams"] = None,
        **kwargs,
    ) -> None:
        """
        Queue a job on the next available device.

        key: name used to track the job.
        fn: callable to run, queued together with args and kwargs.
        needs_device: device to prefer, if it is part of the pool.

        The put does not block, so queue.Full is raised when the device's
        pending queue is already full.
        """
        self.total_jobs += 1
        logger.debug("pool job count: %s", self.total_jobs)
        if self.total_jobs > self.max_jobs_per_worker:
            self.recycle()
            self.total_jobs = 0

        device_idx = self.get_next_device(needs_device=needs_device)
        logger.info(
            "assigning job %s to device %s: %s",
            key,
            device_idx,
            self.devices[device_idx],
        )

        device = self.devices[device_idx].device
        self.pending[device].put((key, fn, args, kwargs), block=False)

    def status(self) -> List[Tuple[str, int, bool, bool]]:
        """
        List every active and finished job as a tuple of
        (name, progress, finished, cancelled), active jobs first.
        """
        # build the lookup once rather than scanning the list per job
        cancelled = set(self.cancelled_jobs)
        history = [
            (name, progress, False, name in cancelled)
            for name, (_device, progress) in self.active_jobs.items()
        ]
        history.extend(
            (name, progress, True, cancel)
            for name, progress, cancel in self.finished_jobs
        )
        return history
|