1
0
Fork 0

fix(api): use lock when restarting workers

This commit is contained in:
Sean Sube 2023-03-25 09:47:51 -05:00
parent 2c47904057
commit 88f4713e23
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 75 additions and 66 deletions

View File

@ -1,7 +1,7 @@
from collections import Counter
from logging import getLogger
from queue import Empty
from threading import Thread
from threading import Thread, Lock
from typing import Callable, Dict, List, Optional, Tuple
from torch.multiprocessing import Process, Queue, Value
@ -37,6 +37,7 @@ class DevicePoolExecutor:
logs: "Queue[str]"
progress: "Queue[ProgressCommand]"
rlock: Lock
def __init__(
self,
@ -66,6 +67,7 @@ class DevicePoolExecutor:
self.logs = Queue(self.max_pending_per_worker)
self.progress = Queue(self.max_pending_per_worker)
self.rlock = Lock()
# TODO: these should be part of a start method
self.create_logger_worker()
@ -223,6 +225,7 @@ class DevicePoolExecutor:
def join(self):
logger.info("stopping worker pool")
with self.rlock:
logger.debug("closing queues")
self.logs.close()
self.progress.close()
@ -272,6 +275,8 @@ class DevicePoolExecutor:
def recycle(self):
logger.debug("recycling worker pool")
with self.rlock:
self.join_leaking()
needs_restart = []
@ -306,19 +311,23 @@ class DevicePoolExecutor:
)
if len(needs_restart) > 0:
logger.debug("starting new workers")
logger.info("starting new workers")
for device in self.devices:
if device.device in needs_restart:
self.create_device_worker(device)
self.total_jobs[device.device] = 0
if not self.threads["logger"].is_alive():
logger.warning("restarting crashed logger worker")
if self.threads["logger"].is_alive():
logger.debug("logger worker is running")
else:
logger.warning("restarting logger worker")
self.create_logger_worker()
if not self.threads["progress"].is_alive():
logger.warning("restarting crashed progress worker")
if self.threads["progress"].is_alive():
logger.debug("progress worker is running")
else:
logger.warning("restarting progress worker")
self.create_progress_worker()
logger.debug("worker pool recycled")