
fix(api): use lock when restarting workers

Sean Sube 2023-03-25 09:47:51 -05:00
parent 2c47904057
commit 88f4713e23
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 75 additions and 66 deletions
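The change wraps the worker-restart paths of DevicePoolExecutor in a shared threading.Lock, so that join() and recycle() cannot tear down and restart workers at the same time. As a minimal standalone sketch of the race being closed (the names below are illustrative, not from the commit): without the lock, two threads can both observe a dead worker and both start a replacement.

    from threading import Lock, Thread

    # Hypothetical stand-in for the pool: a dict of named "workers" that
    # several threads may try to restart at once. The lock serializes the
    # check-then-restart sequence so only one replacement is started.
    workers = {"cuda:0": None}
    restart_lock = Lock()
    restarts = []

    def recycle(caller):
        with restart_lock:  # serialize the check-then-restart sequence
            if workers["cuda:0"] is None:
                workers["cuda:0"] = f"worker started by {caller}"
                restarts.append(caller)

    threads = [Thread(target=recycle, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    print(restarts)  # exactly one restart, whichever thread won the lock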


@@ -1,7 +1,7 @@
 from collections import Counter
 from logging import getLogger
 from queue import Empty
-from threading import Thread
+from threading import Thread, Lock
 from typing import Callable, Dict, List, Optional, Tuple
 
 from torch.multiprocessing import Process, Queue, Value
@@ -37,6 +37,7 @@ class DevicePoolExecutor:
     logs: "Queue[str]"
     progress: "Queue[ProgressCommand]"
+    rlock: Lock
 
     def __init__(
         self,
@@ -66,6 +67,7 @@ class DevicePoolExecutor:
         self.logs = Queue(self.max_pending_per_worker)
         self.progress = Queue(self.max_pending_per_worker)
+        self.rlock = Lock()
 
         # TODO: these should be part of a start method
         self.create_logger_worker()
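One detail worth flagging: the attribute is named rlock but is initialized as threading.Lock, which is not reentrant despite the name. That is safe here because neither join() nor recycle() acquires the lock twice on the same thread, but the distinction matters if the class ever re-enters these methods while holding it. A short sketch of the difference:

    from threading import Lock, RLock

    # An RLock may be re-acquired by the thread that already holds it;
    # a plain Lock may not, and would deadlock on a blocking re-acquire.
    reentrant = RLock()
    reentrant.acquire()
    reentrant.acquire()  # ok: same thread, reentrant
    reentrant.release()
    reentrant.release()

    plain = Lock()
    plain.acquire()
    print(plain.acquire(blocking=False))  # False: already held, cannot re-enter
    plain.release()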
@@ -223,35 +225,36 @@ class DevicePoolExecutor:
     def join(self):
         logger.info("stopping worker pool")
 
-        logger.debug("closing queues")
-        self.logs.close()
-        self.progress.close()
-        for queue in self.pending.values():
-            queue.close()
+        with self.rlock:
+            logger.debug("closing queues")
+            self.logs.close()
+            self.progress.close()
+            for queue in self.pending.values():
+                queue.close()
 
-        self.pending.clear()
-        self.join_leaking()
+            self.pending.clear()
+            self.join_leaking()
 
-        logger.debug("stopping device workers")
-        for device, worker in self.workers.items():
-            if worker.is_alive():
-                logger.debug("stopping worker %s for device %s", worker.pid, device)
-                worker.join(self.join_timeout)
-                if worker.is_alive():
-                    logger.warning(
-                        "worker %s for device %s could not be stopped in time",
-                        worker.pid,
-                        device,
-                    )
-                    self.leaking.append((device, worker))
-            else:
-                logger.debug("worker for device %s has died", device)
+            logger.debug("stopping device workers")
+            for device, worker in self.workers.items():
+                if worker.is_alive():
+                    logger.debug("stopping worker %s for device %s", worker.pid, device)
+                    worker.join(self.join_timeout)
+                    if worker.is_alive():
+                        logger.warning(
+                            "worker %s for device %s could not be stopped in time",
+                            worker.pid,
+                            device,
+                        )
+                        self.leaking.append((device, worker))
+                else:
+                    logger.debug("worker for device %s has died", device)
 
-        for name, thread in self.threads.items():
-            logger.debug("stopping worker %s for thread %s", thread.ident, name)
-            thread.join(self.join_timeout)
+            for name, thread in self.threads.items():
+                logger.debug("stopping worker %s for thread %s", thread.ident, name)
+                thread.join(self.join_timeout)
 
-        logger.debug("worker pool stopped")
+            logger.debug("worker pool stopped")
 
     def join_leaking(self):
         if len(self.leaking) > 0:
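The worker.join(self.join_timeout) / is_alive() pairing above works because join(timeout) returns once the timeout expires whether or not the process exited; workers that outlive the timeout are parked on self.leaking instead of blocking shutdown. A small standalone sketch of that pattern, using the stdlib multiprocessing rather than torch.multiprocessing:

    from multiprocessing import Process
    import time

    if __name__ == "__main__":
        # join(timeout) returns once the timeout expires even if the child
        # is still running, so the caller must re-check is_alive(); that is
        # how join() above decides a worker is "leaking".
        p = Process(target=time.sleep, args=(5,))
        p.start()
        p.join(0.1)          # short grace period, like self.join_timeout
        print(p.is_alive())  # True: the child did not exit in time
        p.terminate()        # clean up for the example
        p.join()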
@@ -272,56 +275,62 @@ class DevicePoolExecutor:
     def recycle(self):
         logger.debug("recycling worker pool")
 
-        self.join_leaking()
-        needs_restart = []
+        with self.rlock:
+            self.join_leaking()
+            needs_restart = []
 
-        for device, worker in self.workers.items():
-            jobs = self.total_jobs.get(device, 0)
-            if not worker.is_alive():
-                logger.warning("worker for device %s has died", device)
-                needs_restart.append(device)
-            elif jobs > self.max_jobs_per_worker:
-                logger.info(
-                    "shutting down worker for device %s after %s jobs", device, jobs
-                )
-                worker.join(self.join_timeout)
-                if worker.is_alive():
-                    logger.warning(
-                        "worker %s for device %s could not be recycled in time",
-                        worker.pid,
-                        device,
-                    )
-                    self.leaking.append((device, worker))
-                else:
-                    del worker
+            for device, worker in self.workers.items():
+                jobs = self.total_jobs.get(device, 0)
+                if not worker.is_alive():
+                    logger.warning("worker for device %s has died", device)
+                    needs_restart.append(device)
+                elif jobs > self.max_jobs_per_worker:
+                    logger.info(
+                        "shutting down worker for device %s after %s jobs", device, jobs
+                    )
+                    worker.join(self.join_timeout)
+                    if worker.is_alive():
+                        logger.warning(
+                            "worker %s for device %s could not be recycled in time",
+                            worker.pid,
+                            device,
+                        )
+                        self.leaking.append((device, worker))
+                    else:
+                        del worker
 
-                self.workers[device] = None
-                needs_restart.append(device)
-            else:
-                logger.debug(
-                    "worker %s for device %s does not need to be recycled",
-                    worker.pid,
-                    device,
-                )
+                    self.workers[device] = None
+                    needs_restart.append(device)
+                else:
+                    logger.debug(
+                        "worker %s for device %s does not need to be recycled",
+                        worker.pid,
+                        device,
+                    )
 
-        if len(needs_restart) > 0:
-            logger.debug("starting new workers")
+            if len(needs_restart) > 0:
+                logger.info("starting new workers")
 
-        for device in self.devices:
-            if device.device in needs_restart:
-                self.create_device_worker(device)
-                self.total_jobs[device.device] = 0
+            for device in self.devices:
+                if device.device in needs_restart:
+                    self.create_device_worker(device)
+                    self.total_jobs[device.device] = 0
 
-        if not self.threads["logger"].is_alive():
-            logger.warning("restarting crashed logger worker")
-            self.create_logger_worker()
+            if self.threads["logger"].is_alive():
+                logger.debug("logger worker is running")
+            else:
+                logger.warning("restarting logger worker")
+                self.create_logger_worker()
 
-        if not self.threads["progress"].is_alive():
-            logger.warning("restarting crashed progress worker")
-            self.create_progress_worker()
+            if self.threads["progress"].is_alive():
+                logger.debug("progress worker is running")
+            else:
+                logger.warning("restarting progress worker")
+                self.create_progress_worker()
 
-        logger.debug("worker pool recycled")
+            logger.debug("worker pool recycled")
 
     def submit(
         self,
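For completeness, a hedged sketch of how a caller might drive recycle() from a monitor thread; the helper name and interval below are assumptions for illustration, not part of onnx-web. Because recycle() and join() now take the same lock, a periodic recycle racing with shutdown is harmless: whichever acquires rlock first runs to completion before the other starts.

    from threading import Event, Thread

    # Hypothetical helper: call pool.recycle() every `interval` seconds
    # until `stop` is set. The lock inside recycle() makes this safe to
    # run concurrently with pool.join().
    def start_monitor(pool, stop: Event, interval: float = 10.0):
        def loop():
            while not stop.wait(interval):  # wake up every `interval` seconds
                pool.recycle()

        thread = Thread(target=loop, daemon=True, name="pool-monitor")
        thread.start()
        return thread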