
feat(api): check device worker pool and recycle on a regular interval (#284)

Sean Sube 2023-03-26 11:09:13 -05:00
parent aeb71ad50a
commit e552a5560f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 97 additions and 62 deletions
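
The heart of the change is a new health worker: it recycles the device pool on a regular interval and restarts the logger and progress workers when their queues fill. A minimal sketch of how the new constructor parameters drive this, based on the __init__ signature in the diff below (server and devices are placeholder values for a ServerContext and a list of DeviceParams):

    pool = DevicePoolExecutor(
        server,
        devices,
        recycle_interval=10,    # health-check and recycle the pool every 10 seconds
        progress_interval=1.0,  # poll the progress queues every second
    )
    # ... submit jobs ...
    pool.join()  # close queues, then join the progress, logger, and health workers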


@@ -27,7 +27,7 @@ def buffer_external_data_tensors(

     for tensor in model.graph.initializer:
         name = tensor.name
-        logger.debug("externalizing tensor: %s", name)
+        logger.trace("externalizing tensor: %s", name)
         if tensor.HasField("raw_data"):
             npt = numpy_helper.to_array(tensor)
             orv = OrtValue.ortvalue_from_numpy(npt)
@@ -74,7 +74,7 @@ def blend_loras(
     blended: Dict[str, np.ndarray] = {}

     for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
-        logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
+        logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
         if lora_model is None:
             logger.warning("unable to load tensor for LoRA")
             continue


@@ -10,6 +10,7 @@ from ..params import DeviceParams
 from ..server import ServerContext
 from .command import JobCommand, ProgressCommand
 from .context import WorkerContext
+from .utils import Interval
 from .worker import worker_main

 logger = getLogger(__name__)
@@ -18,17 +19,24 @@ logger = getLogger(__name__)
 class DevicePoolExecutor:
     server: ServerContext
     devices: List[DeviceParams]
-    join_timeout: float
     max_jobs_per_worker: int
     max_pending_per_worker: int
+    join_timeout: float
+    progress_interval: float
+    recycle_interval: float

     leaking: List[Tuple[str, Process]]
     context: Dict[str, WorkerContext]  # Device -> Context
     current: Dict[str, "Value[int]"]  # Device -> pid
     pending: Dict[str, "Queue[JobCommand]"]
-    threads: Dict[str, Thread]
+    progress: Dict[str, "Queue[ProgressCommand]"]
     workers: Dict[str, Process]

+    health_worker: Interval
+    logger_worker: Thread
+    progress_worker: Interval
+
     cancelled_jobs: List[str]
     finished_jobs: List[ProgressCommand]
     pending_jobs: List[JobCommand]
@@ -36,8 +44,7 @@ class DevicePoolExecutor:
     total_jobs: Dict[str, int]  # Device -> job count

     logs: "Queue[str]"
-    progress: "Queue[ProgressCommand]"
-    recycle: Lock
+    rlock: Lock

     def __init__(
         self,
@@ -45,18 +52,23 @@
         devices: List[DeviceParams],
         max_pending_per_worker: int = 100,
         join_timeout: float = 1.0,
+        recycle_interval: float = 10,
+        progress_interval: float = 1.0,
     ):
         self.server = server
         self.devices = devices
-        self.join_timeout = join_timeout
         self.max_jobs_per_worker = server.job_limit
         self.max_pending_per_worker = max_pending_per_worker
+        self.join_timeout = join_timeout
+        self.progress_interval = progress_interval
+        self.recycle_interval = recycle_interval

         self.leaking = []
         self.context = {}
         self.current = {}
         self.pending = {}
-        self.threads = {}
+        self.progress = {}
         self.workers = {}

         self.cancelled_jobs = []
@@ -66,10 +78,10 @@
         self.total_jobs = {}

         self.logs = Queue(self.max_pending_per_worker)
-        self.progress = Queue(self.max_pending_per_worker)
-        self.recycle = Lock()
+        self.rlock = Lock()

+        # TODO: these should be part of a start method
+        self.create_health_worker()
         self.create_logger_worker()
         self.create_progress_worker()
@@ -79,14 +91,9 @@
     def create_device_worker(self, device: DeviceParams) -> None:
         name = device.device

-        # reuse the queue if possible, to keep queued jobs
-        if name in self.pending:
-            logger.debug("using existing pending job queue")
-            pending = self.pending[name]
-        else:
-            logger.debug("creating new pending job queue")
-            pending = Queue(self.max_pending_per_worker)
-            self.pending[name] = pending
+        # always recreate queues
+        self.progress[name] = Queue(self.max_pending_per_worker)
+        self.pending[name] = Queue(self.max_pending_per_worker)

         if name in self.current:
             logger.debug("using existing current worker value")
@@ -100,9 +107,9 @@
             name,
             device,
             cancel=Value("B", False),
-            progress=self.progress,
+            progress=self.progress[name],
             logs=self.logs,
-            pending=pending,
+            pending=self.pending[name],
             active_pid=current,
         )
         self.context[name] = context
@@ -117,8 +124,15 @@
         self.workers[name] = worker
         current.value = worker.pid

+    def create_health_worker(self) -> None:
+        self.health_worker = Interval(self.recycle_interval, health_main, args=(self,))
+        self.health_worker.daemon = True
+        self.health_worker.name = "onnx-web health"
+        logger.debug("starting health worker")
+        self.health_worker.start()
+
     def create_logger_worker(self) -> None:
-        logger_thread = Thread(
+        self.logger_worker = Thread(
             name="onnx-web logger",
             target=logger_main,
             args=(
@@ -127,25 +141,18 @@
             ),
             daemon=True,
         )
-        self.threads["logger"] = logger_thread
-
         logger.debug("starting logger worker")
-        logger_thread.start()
+        self.logger_worker.start()

     def create_progress_worker(self) -> None:
-        progress_thread = Thread(
-            name="onnx-web progress",
-            target=progress_main,
-            args=(
-                self,
-                self.progress,
-            ),
-            daemon=True,
+        self.progress_worker = Interval(
+            self.progress_interval, progress_main, args=(self, self.progress)
         )
-        self.threads["progress"] = progress_thread
-
+        self.progress_worker.daemon = True
+        self.progress_worker.name = "onnx-web progress"
         logger.debug("starting progress worker")
-        progress_thread.start()
+        self.progress_worker.start()

     def get_job_context(self, key: str) -> WorkerContext:
         device, _progress = self.running_jobs[key]
@@ -225,10 +232,11 @@
     def join(self):
         logger.info("stopping worker pool")

-        with self.recycle:
+        with self.rlock:
             logger.debug("closing queues")
             self.logs.close()
-            self.progress.close()
+            for queue in self.progress.values():
+                queue.close()
             for queue in self.pending.values():
                 queue.close()
@@ -250,9 +258,14 @@
             else:
                 logger.debug("worker for device %s has died", device)

-        for name, thread in self.threads.items():
-            logger.debug("stopping worker %s for thread %s", thread.ident, name)
-            thread.join(self.join_timeout)
+        logger.debug("stopping progress worker")
+        self.progress_worker.join(self.join_timeout)
+
+        logger.debug("stopping logger worker")
+        self.logger_worker.join(self.join_timeout)
+
+        logger.debug("stopping health worker")
+        self.health_worker.join(self.join_timeout)

         logger.debug("worker pool stopped")
@@ -276,7 +289,7 @@
     def recycle(self):
         logger.debug("recycling worker pool")

-        with self.recycle:
+        with self.rlock:
             self.join_leaking()

             needs_restart = []
@@ -318,22 +331,14 @@
             self.create_device_worker(device)
             self.total_jobs[device.device] = 0

-        if self.threads["logger"].is_alive():
+        if self.logger_worker.is_alive():
             logger.debug("logger worker is running")
-            if self.logs.full():
-                logger.warning("logger queue is full, restarting worker")
-                self.threads["logger"].join(self.join_timeout)
-                self.create_logger_worker()
         else:
             logger.warning("restarting logger worker")
             self.create_logger_worker()

-        if self.threads["progress"].is_alive():
+        if self.progress_worker.is_alive():
             logger.debug("progress worker is running")
-            if self.progress.full():
-                logger.warning("progress queue is full, restarting worker")
-                self.threads["progress"].join(self.join_timeout)
-                self.create_progress_worker()
         else:
             logger.warning("restarting progress worker")
             self.create_progress_worker()
@@ -366,7 +371,7 @@

         # recycle before attempting to run
         logger.debug("job count for device %s: %s", device, self.total_jobs[device])
-        self.recycle()
+        self.rlock()

         # build and queue job
         job = JobCommand(key, device, fn, args, kwargs)
@@ -443,15 +448,28 @@
         self.context[progress.device].set_cancel()


-def logger_main(pool: DevicePoolExecutor, logs: Queue):
+def health_main(pool: DevicePoolExecutor):
+    logger.trace("checking in from health worker thread")
+    pool.recycle()
+
+    if pool.logs.full():
+        logger.warning("logger queue is full, restarting worker")
+        pool.logger_worker.join(pool.join_timeout)
+        pool.create_logger_worker()
+
+    if any([queue.full() for queue in pool.progress.values()]):
+        logger.warning("progress queue is full, restarting worker")
+        pool.progress_worker.join(pool.join_timeout)
+        pool.create_progress_worker()
+
+
+def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"):
     logger.trace("checking in from logger worker thread")

     while True:
         try:
-            job = logs.get(timeout=(pool.join_timeout / 2))
-            with open("worker.log", "w") as f:
-                logger.info("got log: %s", job)
-                f.write(str(job) + "\n\n")
+            msg = logs.get(pool.join_timeout / 2)
+            logger.debug("received logs from worker: %s", msg)
         except Empty:
             pass
         except ValueError:
@@ -460,15 +478,21 @@ def logger_main(pool: DevicePoolExecutor, logs: Queue):
             logger.exception("error in log worker")


-def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"):
+def progress_main(
+    pool: DevicePoolExecutor, queues: Dict[str, "Queue[ProgressCommand]"]
+):
     logger.trace("checking in from progress worker thread")
-    while True:
+    for device, queue in queues.items():
         try:
-            progress = queue.get(timeout=(pool.join_timeout / 2))
-            pool.update_job(progress)
+            progress = queue.get_nowait()
+            while progress is not None:
+                pool.update_job(progress)
+                progress = queue.get_nowait()
         except Empty:
+            logger.trace("empty queue in progress worker for device %s", device)
             pass
-        except ValueError:
+        except ValueError as e:
+            logger.debug("value error in progress worker for device %s: %s", device, e)
             break
         except Exception:
-            logger.exception("error in progress worker")
+            logger.exception("error in progress worker for device %s", device)

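Because progress_main is now called on a timer instead of looping forever in its own thread, it drains each device queue with get_nowait rather than a blocking get. A standalone sketch of that drain pattern (illustrative only: it uses queue.Queue, while the pool passes multiprocessing queues):

    from queue import Empty, Queue

    def drain(q: Queue) -> list:
        # pull everything currently queued, returning once the queue is empty
        items = []
        try:
            item = q.get_nowait()
            while item is not None:
                items.append(item)
                item = q.get_nowait()
        except Empty:
            pass
        return items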

@@ -0,0 +1,11 @@
+from threading import Timer
+
+
+class Interval(Timer):
+    """
+    From https://stackoverflow.com/a/48741004
+    """
+
+    def run(self):
+        while not self.finished.wait(self.interval):
+            self.function(*self.args, **self.kwargs)
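
Unlike a plain threading.Timer, which fires once, this run loops: finished.wait(self.interval) returns False on each timeout, so the callback repeats until cancel() sets the finished event. A usage sketch (the callback and interval are illustrative):

    timer = Interval(10.0, print, args=("tick",))  # same signature as threading.Timer
    timer.daemon = True  # let the process exit without an explicit cancel
    timer.start()
    # ... later ...
    timer.cancel()  # set the finished event; the loop exits at the next wait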