fix(api): use lock when restarting workers

parent: 2c47904057
commit: 88f4713e23
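Both join() and recycle() mutate the pool's shared state (queues, worker and thread tables) and can be triggered concurrently, so they now serialize on a shared threading.Lock. A minimal sketch of the check-then-act race the lock prevents, with illustrative names that are not from this repo:

# Sketch: several threads notice the same dead worker; without the lock,
# each could start its own replacement. With it, exactly one restart runs.
from threading import Lock, Thread

class PoolSketch:
    def __init__(self):
        self.rlock = Lock()
        self.worker_alive = False
        self.restarts = 0

    def recycle(self):
        with self.rlock:  # serialize the check and the restart
            if not self.worker_alive:
                self.worker_alive = True
                self.restarts += 1

pool = PoolSketch()
threads = [Thread(target=pool.recycle) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert pool.restarts == 1  # exactly one thread performed the restart

With the lock held, every later caller observes the already-restarted state instead of repeating the work.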
@@ -1,7 +1,7 @@
 from collections import Counter
 from logging import getLogger
 from queue import Empty
-from threading import Thread
+from threading import Thread, Lock
 from typing import Callable, Dict, List, Optional, Tuple
 
 from torch.multiprocessing import Process, Queue, Value
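Note that Lock is imported from threading, not from torch.multiprocessing: it only serializes threads within the API process, which is the right scope here because the Process handles and Queue objects being managed live in the parent process. A sketch of that split, using the stdlib multiprocessing as a stand-in for torch.multiprocessing:

from multiprocessing import Process  # stand-in for torch.multiprocessing
from threading import Lock, Thread

lock = Lock()  # process-local: guards this process's bookkeeping
workers = {}

def manage(name):
    with lock:  # management threads race over the same parent-side dict
        workers[name] = Process(target=print, args=(name,))

if __name__ == "__main__":
    threads = [Thread(target=manage, args=(f"cuda:{i}",)) for i in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(sorted(workers))  # ['cuda:0', 'cuda:1']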
@@ -37,6 +37,7 @@ class DevicePoolExecutor:
 
     logs: "Queue[str]"
     progress: "Queue[ProgressCommand]"
+    rlock: Lock
 
     def __init__(
         self,
@@ -66,6 +67,7 @@
 
         self.logs = Queue(self.max_pending_per_worker)
         self.progress = Queue(self.max_pending_per_worker)
+        self.rlock = Lock()
 
         # TODO: these should be part of a start method
         self.create_logger_worker()
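Constructing the lock once in __init__ is what makes the with self.rlock: blocks below mutually exclusive; a lock created inside each call would never be contended. A minimal sketch of the distinction (hypothetical classes, not from this repo):

from threading import Lock

class Broken:
    def recycle(self):
        with Lock():  # fresh lock per call: acquired instantly, guards nothing
            pass

class Fixed:
    def __init__(self):
        self.rlock = Lock()  # one lock shared for the pool's lifetime

    def recycle(self):
        with self.rlock:  # every caller contends on the same lock
            pass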
@@ -223,35 +225,36 @@
     def join(self):
         logger.info("stopping worker pool")
 
-        logger.debug("closing queues")
-        self.logs.close()
-        self.progress.close()
-        for queue in self.pending.values():
-            queue.close()
-
-        self.pending.clear()
-        self.join_leaking()
-
-        logger.debug("stopping device workers")
-        for device, worker in self.workers.items():
-            if worker.is_alive():
-                logger.debug("stopping worker %s for device %s", worker.pid, device)
-                worker.join(self.join_timeout)
-                if worker.is_alive():
-                    logger.warning(
-                        "worker %s for device %s could not be stopped in time",
-                        worker.pid,
-                        device,
-                    )
-                    self.leaking.append((device, worker))
-            else:
-                logger.debug("worker for device %s has died", device)
-
-        for name, thread in self.threads.items():
-            logger.debug("stopping worker %s for thread %s", thread.ident, name)
-            thread.join(self.join_timeout)
+        with self.rlock:
+            logger.debug("closing queues")
+            self.logs.close()
+            self.progress.close()
+            for queue in self.pending.values():
+                queue.close()
+
+            self.pending.clear()
+            self.join_leaking()
+
+            logger.debug("stopping device workers")
+            for device, worker in self.workers.items():
+                if worker.is_alive():
+                    logger.debug("stopping worker %s for device %s", worker.pid, device)
+                    worker.join(self.join_timeout)
+                    if worker.is_alive():
+                        logger.warning(
+                            "worker %s for device %s could not be stopped in time",
+                            worker.pid,
+                            device,
+                        )
+                        self.leaking.append((device, worker))
+                else:
+                    logger.debug("worker for device %s has died", device)
+
+            for name, thread in self.threads.items():
+                logger.debug("stopping worker %s for thread %s", thread.ident, name)
+                thread.join(self.join_timeout)
 
         logger.debug("worker pool stopped")
 
     def join_leaking(self):
         if len(self.leaking) > 0:
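Using the with statement rather than manual acquire()/release() also guarantees the lock is dropped if one of the close() or join() calls raises, so a failed shutdown cannot leave the lock held and deadlock a later recycle(). A small self-contained sketch:

from threading import Lock

rlock = Lock()

def flaky_close():
    raise RuntimeError("queue already closed")

try:
    with rlock:        # acquired here
        flaky_close()  # raises mid-critical-section
except RuntimeError:
    pass

assert not rlock.locked()  # the with block released it on the way out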
@@ -272,56 +275,62 @@
 
     def recycle(self):
         logger.debug("recycling worker pool")
-        self.join_leaking()
-
-        needs_restart = []
-
-        for device, worker in self.workers.items():
-            jobs = self.total_jobs.get(device, 0)
-            if not worker.is_alive():
-                logger.warning("worker for device %s has died", device)
-                needs_restart.append(device)
-            elif jobs > self.max_jobs_per_worker:
-                logger.info(
-                    "shutting down worker for device %s after %s jobs", device, jobs
-                )
-                worker.join(self.join_timeout)
-                if worker.is_alive():
-                    logger.warning(
-                        "worker %s for device %s could not be recycled in time",
-                        worker.pid,
-                        device,
-                    )
-                    self.leaking.append((device, worker))
-                else:
-                    del worker
-
-                self.workers[device] = None
-                needs_restart.append(device)
-            else:
-                logger.debug(
-                    "worker %s for device %s does not need to be recycled",
-                    worker.pid,
-                    device,
-                )
-
-        if len(needs_restart) > 0:
-            logger.debug("starting new workers")
-
-        for device in self.devices:
-            if device.device in needs_restart:
-                self.create_device_worker(device)
-                self.total_jobs[device.device] = 0
-
-        if not self.threads["logger"].is_alive():
-            logger.warning("restarting crashed logger worker")
-            self.create_logger_worker()
-
-        if not self.threads["progress"].is_alive():
-            logger.warning("restarting crashed progress worker")
-            self.create_progress_worker()
+
+        with self.rlock:
+            self.join_leaking()
+
+            needs_restart = []
+
+            for device, worker in self.workers.items():
+                jobs = self.total_jobs.get(device, 0)
+                if not worker.is_alive():
+                    logger.warning("worker for device %s has died", device)
+                    needs_restart.append(device)
+                elif jobs > self.max_jobs_per_worker:
+                    logger.info(
+                        "shutting down worker for device %s after %s jobs", device, jobs
+                    )
+                    worker.join(self.join_timeout)
+                    if worker.is_alive():
+                        logger.warning(
+                            "worker %s for device %s could not be recycled in time",
+                            worker.pid,
+                            device,
+                        )
+                        self.leaking.append((device, worker))
+                    else:
+                        del worker
+
+                    self.workers[device] = None
+                    needs_restart.append(device)
+                else:
+                    logger.debug(
+                        "worker %s for device %s does not need to be recycled",
+                        worker.pid,
+                        device,
+                    )
+
+            if len(needs_restart) > 0:
+                logger.info("starting new workers")
+
+            for device in self.devices:
+                if device.device in needs_restart:
+                    self.create_device_worker(device)
+                    self.total_jobs[device.device] = 0
+
+            if self.threads["logger"].is_alive():
+                logger.debug("logger worker is running")
+            else:
+                logger.warning("restarting logger worker")
+                self.create_logger_worker()
+
+            if self.threads["progress"].is_alive():
+                logger.debug("progress worker is running")
+            else:
+                logger.warning("restarting progress worker")
+                self.create_progress_worker()
 
         logger.debug("worker pool recycled")
 
     def submit(
         self,
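One caveat: despite the rlock name, this is a plain threading.Lock, not a reentrant threading.RLock, so recycle() and join() must not call each other (or anything else that takes the lock) from inside their own with self.rlock: blocks, or the pool would deadlock on itself. A short sketch of the difference:

from threading import Lock, RLock

plain = Lock()
with plain:
    # a plain Lock cannot be re-acquired by the thread that holds it
    print(plain.acquire(timeout=0.1))  # False

reentrant = RLock()
with reentrant:
    with reentrant:  # fine: RLock counts nested acquisitions per thread
        print("re-entered")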