# onnx-web/api/onnx_web/worker/pool.py
from collections import Counter
from logging import getLogger
from queue import Empty
from threading import Lock, Thread
from typing import Callable, Dict, List, Optional, Tuple

from torch.multiprocessing import Process, Queue, Value

from ..params import DeviceParams
from ..server import ServerContext
from .command import JobCommand, ProgressCommand
from .context import WorkerContext
from .utils import Interval
from .worker import worker_main

logger = getLogger(__name__)


class DevicePoolExecutor:
server: ServerContext
devices: List[DeviceParams]
join_timeout: float
max_jobs_per_worker: int
max_pending_per_worker: int
progress_interval: float
recycle_interval: float
leaking: List[Tuple[str, Process, WorkerContext]]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"] # Device -> pid
pending: Dict[str, "Queue[JobCommand]"]
progress: Dict[str, "Queue[ProgressCommand]"]
workers: Dict[str, Process]
health_worker: Interval
logger_worker: Thread
progress_worker: Interval
cancelled_jobs: List[str]
finished_jobs: List[ProgressCommand]
pending_jobs: List[JobCommand]
running_jobs: Dict[str, ProgressCommand] # Device -> job progress
total_jobs: Dict[str, int] # Device -> job count
logs: "Queue[str]"
    rlock: Lock

    def __init__(
        self,
        server: ServerContext,
        devices: List[DeviceParams],
        max_pending_per_worker: int = 100,
        join_timeout: float = 1.0,
        recycle_interval: float = 10,
        progress_interval: float = 1.0,
    ):
        self.server = server
        self.devices = devices
        self.join_timeout = join_timeout
        self.max_jobs_per_worker = server.job_limit
        self.max_pending_per_worker = max_pending_per_worker
        self.progress_interval = progress_interval
        self.recycle_interval = recycle_interval

        self.leaking = []
        self.context = {}
        self.current = {}
        self.pending = {}
        self.progress = {}
        self.workers = {}

        self.cancelled_jobs = []
        self.finished_jobs = []
        self.pending_jobs = []
        self.running_jobs = {}
        self.total_jobs = {}

        self.logs = Queue(self.max_pending_per_worker)
        self.rlock = Lock()

    def start(self) -> None:
self.create_health_worker()
self.create_logger_worker()
self.create_progress_worker()
for device in self.devices:
            self.create_device_worker(device)

    def create_device_worker(self, device: DeviceParams) -> None:
        name = device.device

        # always recreate queues
        self.progress[name] = Queue(self.max_pending_per_worker)
        self.pending[name] = Queue(self.max_pending_per_worker)
        self.total_jobs[device.device] = 0

        # reuse pid sentinel
        if name in self.current:
            logger.debug("using existing current worker value")
            current = self.current[name]
        else:
            logger.debug("creating new current worker value")
            current = Value("L", 0)
            self.current[name] = current

        # create a new context and worker
        context = WorkerContext(
            name,
            device,
            cancel=Value("B", False),
            progress=self.progress[name],
            logs=self.logs,
            pending=self.pending[name],
            active_pid=current,
        )
        self.context[name] = context

        worker = Process(
            name=f"onnx-web worker: {name}",
            target=worker_main,
            args=(context, self.server),
        )
        worker.daemon = True
        self.workers[name] = worker

        logger.debug("starting worker for device %s", device)
        worker.start()
        current.value = worker.pid

    def create_health_worker(self) -> None:
self.health_worker = Interval(self.recycle_interval, health_main, args=(self,))
self.health_worker.daemon = True
self.health_worker.name = "onnx-web health"
logger.debug("starting health worker")
        self.health_worker.start()

    def create_logger_worker(self) -> None:
        self.logger_worker = Thread(
            name="onnx-web logger",
            target=logger_main,
            args=(
                self,
                self.logs,
            ),
            daemon=True,
        )

        logger.debug("starting logger worker")
        self.logger_worker.start()

    def create_progress_worker(self) -> None:
self.progress_worker = Interval(
self.progress_interval, progress_main, args=(self,)
)
self.progress_worker.daemon = True
self.progress_worker.name = "onnx-web progress"
logger.debug("starting progress worker")
        self.progress_worker.start()

    def get_job_context(self, key: str) -> WorkerContext:
        # running_jobs maps a job key to its most recent ProgressCommand
        job = self.running_jobs[key]
        return self.context[job.device]

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        # respect overrides if possible
        if needs_device is not None:
            for i in range(len(self.devices)):
                if self.devices[i].device == needs_device.device:
                    return i

        # count pending jobs for each device index, starting every device at 1
        # so that idle devices still appear in the counter
        jobs = Counter(range(len(self.devices)))
        jobs.update(
            {i: self.pending[d.device].qsize() for i, d in enumerate(self.devices)}
        )

        queued = jobs.most_common()
        logger.trace("jobs queued by device: %s", queued)

        lowest_count = queued[-1][1]
        lowest_devices = [d[0] for d in queued if d[1] == lowest_count]
        lowest_devices.sort()

        return lowest_devices[0]

    def cancel(self, key: str) -> bool:
"""
        Cancel a job. If the job has not been started, it is removed from the
        pending list and will never execute. If the job has already started,
        it will be cancelled on the next progress callback.
"""
for job in self.finished_jobs:
if job.job == key:
logger.debug("cannot cancel finished job: %s", key)
return False
for job in self.pending_jobs:
if job.name == key:
self.pending_jobs.remove(job)
logger.info("cancelled pending job: %s", key)
return True
if key not in self.running_jobs:
logger.debug("cancelled job is not active: %s", key)
else:
job = self.running_jobs[key]
logger.info("cancelling job %s, active on device %s", key, job.device)
2023-02-26 20:36:32 +00:00
self.cancelled_jobs.append(key)
2023-02-26 20:36:32 +00:00
return True
2023-02-26 05:49:39 +00:00
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
"""
Check if a job has been finished and report the last progress update.
If the job is still pending, the first item will be True and there will be no ProgressCommand.
"""
if key in self.running_jobs:
logger.debug("checking status for running job: %s", key)
return (False, self.running_jobs[key])
for job in self.finished_jobs:
if job.job == key:
logger.debug("checking status for finished job: %s", key)
                return (False, job)

        for job in self.pending_jobs:
            if job.name == key:
                logger.debug("checking status for pending job: %s", key)
                return (True, None)

        logger.trace("checking status for unknown job: %s", key)
        return (False, None)
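    # Illustrative sketch only: how a caller might interpret the result of
    # done(), assuming a started pool named `pool` and a job key `job_id`
    # (hypothetical names):
    #
    #     pending, progress = pool.done(job_id)
    #     if pending:
    #         ...  # still queued, no progress yet
    #     elif progress is None:
    #         ...  # unknown job
    #     elif progress.finished:
    #         ...  # completed; check progress.cancelled and progress.failed
    #     else:
    #         ...  # running; report progress.progress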
def join(self):
logger.info("stopping worker pool")
with self.rlock:
logger.debug("stopping device workers")
for device, worker in self.workers.items():
if worker.is_alive():
logger.debug("stopping worker %s for device %s", worker.pid, device)
worker.join(self.join_timeout)
if worker.is_alive():
logger.warning(
"worker %s for device %s could not be stopped in time",
worker.pid,
device,
)
self.leak_worker(device)
else:
logger.debug("worker for device %s has died", device)
2023-02-26 16:47:31 +00:00
logger.debug("stopping health worker")
self.health_worker.cancel()
self.health_worker.join(self.recycle_interval)
logger.debug("stopping progress worker")
self.progress_worker.cancel()
self.progress_worker.join(self.progress_interval)
logger.debug("stopping logger worker")
self.logger_worker.join(self.join_timeout)
logger.debug("closing worker queues")
self.logs.close()
for queue in self.pending.values():
queue.close()
for queue in self.progress.values():
queue.close()
self.pending.clear()
self.progress.clear()
logger.debug("stopping leaking workers")
self.join_leaking()
logger.debug("worker pool stopped")
def join_leaking(self):
if len(self.leaking) > 0:
for device, worker, context in self.leaking:
logger.warning(
"shutting down leaking worker %s for device %s", worker.pid, device
2023-03-07 14:02:53 +00:00
)
worker.join(self.join_timeout)
if worker.is_alive():
2023-03-06 01:23:23 +00:00
logger.error(
2023-03-07 14:02:53 +00:00
"leaking worker %s for device %s could not be shut down",
worker.pid,
device,
2023-03-06 01:23:23 +00:00
)
try:
progress = context.progress.get_nowait()
while progress is not None:
self.update_job(progress)
progress = context.progress.get_nowait()
except Empty:
logger.trace("empty queue in leaking worker for device %s", device)
except ValueError as e:
2023-03-26 23:24:22 +00:00
logger.debug(
"value error in leaking worker for device %s: %s", device, e
)
except Exception:
logger.exception("error in leaking worker for device %s", device)
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
def recycle(self):
logger.debug("recycling worker pool")
with self.rlock:
self.join_leaking()
needs_restart = []
for device, worker in self.workers.items():
jobs = self.total_jobs.get(device, 0)
if not worker.is_alive():
logger.warning("worker for device %s has died", device)
needs_restart.append(device)
elif jobs > self.max_jobs_per_worker:
logger.info(
"shutting down worker for device %s after %s jobs", device, jobs
)
worker.join(self.join_timeout)
if worker.is_alive():
logger.warning(
"worker %s for device %s could not be recycled in time",
worker.pid,
device,
)
self.leak_worker(device)
else:
del worker
self.workers[device] = None
needs_restart.append(device)
else:
                    logger.debug(
                        "worker %s for device %s has run %s jobs and is still alive",
                        worker.pid,
                        device,
                        jobs,
                    )

            if len(needs_restart) > 0:
                logger.info("starting new workers")

            for device in self.devices:
if device.device in needs_restart:
self.create_device_worker(device)
if self.logger_worker.is_alive():
logger.debug("logger worker is running")
else:
logger.warning("restarting logger worker")
self.create_logger_worker()
if self.progress_worker.is_alive():
logger.debug("progress worker is running")
else:
logger.warning("restarting progress worker")
self.create_progress_worker()
logger.debug("worker pool recycled")
2023-02-26 18:58:38 +00:00
2023-02-26 05:49:39 +00:00
def submit(
self,
key: str,
fn: Callable[..., None],
/,
*args,
needs_device: Optional[DeviceParams] = None,
**kwargs,
) -> None:
device_idx = self.get_next_device(needs_device=needs_device)
        device = self.devices[device_idx].device

        logger.info(
            "assigning job %s to device %s: %s",
            key,
            device_idx,
            device,
        )
# build and queue job
job = JobCommand(key, device, fn, args, kwargs)
self.pending_jobs.append(job)
def status(self) -> Dict[str, List[Tuple[str, int, bool, bool, bool, bool]]]:
"""
        Returns a dict of job lists, where each entry is a tuple of: job or
        device name, progress (total count for devices), pending (worker alive
        for devices), finished, cancelled, failed.
"""
return {
"cancelled": [],
"finished": [
(
job.job,
job.progress,
False,
job.finished,
job.cancelled,
job.failed,
)
for job in self.finished_jobs
],
"pending": [
(
job.name,
0,
True,
False,
False,
False,
)
for job in self.pending_jobs
],
"running": [
2023-02-27 02:09:42 +00:00
(
name,
job.progress,
False,
job.finished,
job.cancelled,
job.failed,
2023-02-27 02:09:42 +00:00
)
for name, job in self.running_jobs.items()
],
"total": [
(
device,
total,
self.workers[device].is_alive(),
False,
False,
False,
)
for device, total in self.total_jobs.items()
],
}
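    # Illustrative sketch only: the returned mapping looks roughly like this,
    # with hypothetical job and device names:
    #
    #     {
    #         "cancelled": [],
    #         "finished": [("job-1", 100, False, True, False, False)],
    #         "pending": [("job-3", 0, True, False, False, False)],
    #         "running": [("job-2", 45, False, False, False, False)],
    #         "total": [("cuda", 7, True, False, False, False)],
    #     }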
def next_job(self, device: str):
for job in self.pending_jobs:
if job.device == device:
logger.debug("enqueuing job %s on device %s", job.name, device)
self.pending[device].put(job, block=False)
# job will be removed from pending queue when progress is updated
return
logger.trace("no pending jobs for device %s", device)
2023-03-26 19:02:57 +00:00
def finish_job(self, progress: ProgressCommand):
# move from running to finished
logger.info("job has finished: %s", progress.job)
self.finished_jobs.append(progress)
if progress.job in self.running_jobs:
del self.running_jobs[progress.job]
self.join_leaking()
if progress.job in self.cancelled_jobs:
            self.cancelled_jobs.remove(progress.job)

    def update_job(self, progress: ProgressCommand):
        if progress.finished:
            return self.finish_job(progress)

        # move from pending to running
logger.debug(
"progress update for job: %s to %s", progress.job, progress.progress
)
self.running_jobs[progress.job] = progress
self.pending_jobs[:] = [
job for job in self.pending_jobs if job.name != progress.job
]
# increment job counter if this is the start of a new job
if progress.progress == 0:
if progress.device in self.total_jobs:
self.total_jobs[progress.device] += 1
else:
self.total_jobs[progress.device] = 1
            logger.debug(
                "updating job count for device %s: %s",
                progress.device,
                self.total_jobs[progress.device],
            )

        # check if the job has been cancelled
if progress.job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s",
progress.job,
progress.device,
)
self.context[progress.device].set_cancel()
def leak_worker(self, device: str):
context = self.context[device]
worker = self.workers[device]
        self.leaking.append((device, worker, context))


def health_main(pool: DevicePoolExecutor):
logger.trace("checking in from health worker thread")
pool.recycle()
if pool.logs.full():
logger.warning("logger queue is full, restarting worker")
pool.logger_worker.join(pool.join_timeout)
pool.create_logger_worker()
if any([queue.full() for queue in pool.progress.values()]):
logger.warning("progress queue is full, restarting worker")
pool.progress_worker.join(pool.join_timeout)
        pool.create_progress_worker()


def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"):
logger.trace("checking in from logger worker thread")
while True:
try:
            msg = logs.get(timeout=pool.join_timeout / 2)
            logger.debug("received logs from worker: %s", msg)
        except Empty:
            # nothing to read; stay quiet so the logger worker does not
            # generate a log message of its own on every empty poll
            pass
        except ValueError:
            break
        except Exception:
            logger.exception("error in log worker")


def progress_main(pool: DevicePoolExecutor):
logger.trace("checking in from progress worker thread")
for device, _worker, context in pool.leaking:
# whether the worker is alive or not, try to clear its queues
try:
progress = context.progress.get_nowait()
while progress is not None:
pool.update_job(progress)
progress = context.progress.get_nowait()
except Empty:
logger.trace("empty queue in leaking worker for device %s", device)
except ValueError as e:
logger.debug("value error in leaking worker for device %s: %s", device, e)
break
except Exception:
logger.exception("error in leaking worker for device %s", device)
for device, queue in pool.progress.items():
try:
progress = queue.get_nowait()
while progress is not None:
pool.update_job(progress)
progress = queue.get_nowait()
except Empty:
logger.trace("empty queue in progress worker for device %s", device)
except ValueError as e:
logger.debug("value error in progress worker for device %s: %s", device, e)
break
except Exception:
logger.exception("error in progress worker for device %s", device)
for device, queue in pool.pending.items():
if queue.empty():
logger.trace("enqueueing next job for idle worker")
pool.next_job(device)
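

# Illustrative usage sketch only, assuming a configured ServerContext in
# `server`, a list of DeviceParams in `devices`, and a worker-compatible
# callable `run_job` with its `job_params` argument (hypothetical names):
#
#     pool = DevicePoolExecutor(server, devices)
#     pool.start()
#
#     pool.submit("job-1", run_job, job_params, needs_device=devices[0])
#
#     pending, progress = pool.done("job-1")
#     if progress is not None and progress.finished:
#         print(progress.cancelled, progress.failed)
#
#     pool.cancel("job-1")
#     pool.join()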