feat(api): check device worker pool and recycle on a regular interval (#284)
This commit is contained in:
parent
aeb71ad50a
commit
e552a5560f
|
@ -27,7 +27,7 @@ def buffer_external_data_tensors(
|
|||
for tensor in model.graph.initializer:
|
||||
name = tensor.name
|
||||
|
||||
logger.debug("externalizing tensor: %s", name)
|
||||
logger.trace("externalizing tensor: %s", name)
|
||||
if tensor.HasField("raw_data"):
|
||||
npt = numpy_helper.to_array(tensor)
|
||||
orv = OrtValue.ortvalue_from_numpy(npt)
|
||||
|
@ -74,7 +74,7 @@ def blend_loras(
|
|||
|
||||
blended: Dict[str, np.ndarray] = {}
|
||||
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
||||
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||
if lora_model is None:
|
||||
logger.warning("unable to load tensor for LoRA")
|
||||
continue
|
||||
|
|
|
@ -10,6 +10,7 @@ from ..params import DeviceParams
|
|||
from ..server import ServerContext
|
||||
from .command import JobCommand, ProgressCommand
|
||||
from .context import WorkerContext
|
||||
from .utils import Interval
|
||||
from .worker import worker_main
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -18,17 +19,24 @@ logger = getLogger(__name__)
|
|||
class DevicePoolExecutor:
|
||||
server: ServerContext
|
||||
devices: List[DeviceParams]
|
||||
|
||||
join_timeout: float
|
||||
max_jobs_per_worker: int
|
||||
max_pending_per_worker: int
|
||||
join_timeout: float
|
||||
progress_interval: float
|
||||
recycle_interval: float
|
||||
|
||||
leaking: List[Tuple[str, Process]]
|
||||
context: Dict[str, WorkerContext] # Device -> Context
|
||||
current: Dict[str, "Value[int]"] # Device -> pid
|
||||
pending: Dict[str, "Queue[JobCommand]"]
|
||||
threads: Dict[str, Thread]
|
||||
progress: Dict[str, "Queue[ProgressCommand]"]
|
||||
workers: Dict[str, Process]
|
||||
|
||||
health_worker: Interval
|
||||
logger_worker: Thread
|
||||
progress_worker: Interval
|
||||
|
||||
cancelled_jobs: List[str]
|
||||
finished_jobs: List[ProgressCommand]
|
||||
pending_jobs: List[JobCommand]
|
||||
|
@ -36,8 +44,7 @@ class DevicePoolExecutor:
|
|||
total_jobs: Dict[str, int] # Device -> job count
|
||||
|
||||
logs: "Queue[str]"
|
||||
progress: "Queue[ProgressCommand]"
|
||||
recycle: Lock
|
||||
rlock: Lock
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -45,18 +52,23 @@ class DevicePoolExecutor:
|
|||
devices: List[DeviceParams],
|
||||
max_pending_per_worker: int = 100,
|
||||
join_timeout: float = 1.0,
|
||||
recycle_interval: float = 10,
|
||||
progress_interval: float = 1.0,
|
||||
):
|
||||
self.server = server
|
||||
self.devices = devices
|
||||
|
||||
self.join_timeout = join_timeout
|
||||
self.max_jobs_per_worker = server.job_limit
|
||||
self.max_pending_per_worker = max_pending_per_worker
|
||||
self.join_timeout = join_timeout
|
||||
self.progress_interval = progress_interval
|
||||
self.recycle_interval = recycle_interval
|
||||
|
||||
self.leaking = []
|
||||
self.context = {}
|
||||
self.current = {}
|
||||
self.pending = {}
|
||||
self.threads = {}
|
||||
self.progress = {}
|
||||
self.workers = {}
|
||||
|
||||
self.cancelled_jobs = []
|
||||
|
@ -66,10 +78,10 @@ class DevicePoolExecutor:
|
|||
self.total_jobs = {}
|
||||
|
||||
self.logs = Queue(self.max_pending_per_worker)
|
||||
self.progress = Queue(self.max_pending_per_worker)
|
||||
self.recycle = Lock()
|
||||
self.rlock = Lock()
|
||||
|
||||
# TODO: these should be part of a start method
|
||||
self.create_health_worker()
|
||||
self.create_logger_worker()
|
||||
self.create_progress_worker()
|
||||
|
||||
|
@ -79,14 +91,9 @@ class DevicePoolExecutor:
|
|||
def create_device_worker(self, device: DeviceParams) -> None:
|
||||
name = device.device
|
||||
|
||||
# reuse the queue if possible, to keep queued jobs
|
||||
if name in self.pending:
|
||||
logger.debug("using existing pending job queue")
|
||||
pending = self.pending[name]
|
||||
else:
|
||||
logger.debug("creating new pending job queue")
|
||||
pending = Queue(self.max_pending_per_worker)
|
||||
self.pending[name] = pending
|
||||
# always recreate queues
|
||||
self.progress[name] = Queue(self.max_pending_per_worker)
|
||||
self.pending[name] = Queue(self.max_pending_per_worker)
|
||||
|
||||
if name in self.current:
|
||||
logger.debug("using existing current worker value")
|
||||
|
@ -100,9 +107,9 @@ class DevicePoolExecutor:
|
|||
name,
|
||||
device,
|
||||
cancel=Value("B", False),
|
||||
progress=self.progress,
|
||||
progress=self.progress[name],
|
||||
logs=self.logs,
|
||||
pending=pending,
|
||||
pending=self.pending[name],
|
||||
active_pid=current,
|
||||
)
|
||||
self.context[name] = context
|
||||
|
@ -117,8 +124,15 @@ class DevicePoolExecutor:
|
|||
self.workers[name] = worker
|
||||
current.value = worker.pid
|
||||
|
||||
def create_health_worker(self) -> None:
|
||||
self.health_worker = Interval(self.recycle_interval, health_main, args=(self,))
|
||||
self.health_worker.daemon = True
|
||||
self.health_worker.name = "onnx-web health"
|
||||
logger.debug("starting health worker")
|
||||
self.health_worker.start()
|
||||
|
||||
def create_logger_worker(self) -> None:
|
||||
logger_thread = Thread(
|
||||
self.logger_worker = Thread(
|
||||
name="onnx-web logger",
|
||||
target=logger_main,
|
||||
args=(
|
||||
|
@ -127,25 +141,18 @@ class DevicePoolExecutor:
|
|||
),
|
||||
daemon=True,
|
||||
)
|
||||
self.threads["logger"] = logger_thread
|
||||
|
||||
logger.debug("starting logger worker")
|
||||
logger_thread.start()
|
||||
self.logger_worker.start()
|
||||
|
||||
def create_progress_worker(self) -> None:
|
||||
progress_thread = Thread(
|
||||
name="onnx-web progress",
|
||||
target=progress_main,
|
||||
args=(
|
||||
self,
|
||||
self.progress,
|
||||
),
|
||||
daemon=True,
|
||||
self.progress_worker = Interval(
|
||||
self.progress_interval, progress_main, args=(self, self.progress)
|
||||
)
|
||||
self.threads["progress"] = progress_thread
|
||||
|
||||
self.progress_worker.daemon = True
|
||||
self.progress_worker.name = "onnx-web progress"
|
||||
logger.debug("starting progress worker")
|
||||
progress_thread.start()
|
||||
self.progress_worker.start()
|
||||
|
||||
def get_job_context(self, key: str) -> WorkerContext:
|
||||
device, _progress = self.running_jobs[key]
|
||||
|
@ -225,10 +232,11 @@ class DevicePoolExecutor:
|
|||
def join(self):
|
||||
logger.info("stopping worker pool")
|
||||
|
||||
with self.recycle:
|
||||
with self.rlock:
|
||||
logger.debug("closing queues")
|
||||
self.logs.close()
|
||||
self.progress.close()
|
||||
for queue in self.progress.values():
|
||||
queue.close()
|
||||
for queue in self.pending.values():
|
||||
queue.close()
|
||||
|
||||
|
@ -250,9 +258,14 @@ class DevicePoolExecutor:
|
|||
else:
|
||||
logger.debug("worker for device %s has died", device)
|
||||
|
||||
for name, thread in self.threads.items():
|
||||
logger.debug("stopping worker %s for thread %s", thread.ident, name)
|
||||
thread.join(self.join_timeout)
|
||||
logger.debug("stopping progress worker")
|
||||
self.progress_worker.join(self.join_timeout)
|
||||
|
||||
logger.debug("stopping logger worker")
|
||||
self.logger_worker.join(self.join_timeout)
|
||||
|
||||
logger.debug("stopping health worker")
|
||||
self.health_worker.join(self.join_timeout)
|
||||
|
||||
logger.debug("worker pool stopped")
|
||||
|
||||
|
@ -276,7 +289,7 @@ class DevicePoolExecutor:
|
|||
def recycle(self):
|
||||
logger.debug("recycling worker pool")
|
||||
|
||||
with self.recycle:
|
||||
with self.rlock:
|
||||
self.join_leaking()
|
||||
|
||||
needs_restart = []
|
||||
|
@ -318,22 +331,14 @@ class DevicePoolExecutor:
|
|||
self.create_device_worker(device)
|
||||
self.total_jobs[device.device] = 0
|
||||
|
||||
if self.threads["logger"].is_alive():
|
||||
if self.logger_worker.is_alive():
|
||||
logger.debug("logger worker is running")
|
||||
if self.logs.full():
|
||||
logger.warning("logger queue is full, restarting worker")
|
||||
self.threads["logger"].join(self.join_timeout)
|
||||
self.create_logger_worker()
|
||||
else:
|
||||
logger.warning("restarting logger worker")
|
||||
self.create_logger_worker()
|
||||
|
||||
if self.threads["progress"].is_alive():
|
||||
if self.progress_worker.is_alive():
|
||||
logger.debug("progress worker is running")
|
||||
if self.progress.full():
|
||||
logger.warning("progress queue is full, restarting worker")
|
||||
self.threads["progress"].join(self.join_timeout)
|
||||
self.create_progress_worker()
|
||||
else:
|
||||
logger.warning("restarting progress worker")
|
||||
self.create_progress_worker()
|
||||
|
@ -366,7 +371,7 @@ class DevicePoolExecutor:
|
|||
|
||||
# recycle before attempting to run
|
||||
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
||||
self.recycle()
|
||||
self.rlock()
|
||||
|
||||
# build and queue job
|
||||
job = JobCommand(key, device, fn, args, kwargs)
|
||||
|
@ -443,15 +448,28 @@ class DevicePoolExecutor:
|
|||
self.context[progress.device].set_cancel()
|
||||
|
||||
|
||||
def logger_main(pool: DevicePoolExecutor, logs: Queue):
|
||||
def health_main(pool: DevicePoolExecutor):
|
||||
logger.trace("checking in from health worker thread")
|
||||
pool.recycle()
|
||||
|
||||
if pool.logs.full():
|
||||
logger.warning("logger queue is full, restarting worker")
|
||||
pool.logger_worker.join(pool.join_timeout)
|
||||
pool.create_logger_worker()
|
||||
|
||||
if any([queue.full() for queue in pool.progress.values()]):
|
||||
logger.warning("progress queue is full, restarting worker")
|
||||
pool.progress_worker.join(pool.join_timeout)
|
||||
pool.create_progress_worker()
|
||||
|
||||
|
||||
def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"):
|
||||
logger.trace("checking in from logger worker thread")
|
||||
|
||||
while True:
|
||||
try:
|
||||
job = logs.get(timeout=(pool.join_timeout / 2))
|
||||
with open("worker.log", "w") as f:
|
||||
logger.info("got log: %s", job)
|
||||
f.write(str(job) + "\n\n")
|
||||
msg = logs.get(pool.join_timeout / 2)
|
||||
logger.debug("received logs from worker: %s", msg)
|
||||
except Empty:
|
||||
pass
|
||||
except ValueError:
|
||||
|
@ -460,15 +478,21 @@ def logger_main(pool: DevicePoolExecutor, logs: Queue):
|
|||
logger.exception("error in log worker")
|
||||
|
||||
|
||||
def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"):
|
||||
def progress_main(
|
||||
pool: DevicePoolExecutor, queues: Dict[str, "Queue[ProgressCommand]"]
|
||||
):
|
||||
logger.trace("checking in from progress worker thread")
|
||||
while True:
|
||||
for device, queue in queues.items():
|
||||
try:
|
||||
progress = queue.get(timeout=(pool.join_timeout / 2))
|
||||
pool.update_job(progress)
|
||||
progress = queue.get_nowait()
|
||||
while progress is not None:
|
||||
pool.update_job(progress)
|
||||
progress = queue.get_nowait()
|
||||
except Empty:
|
||||
logger.trace("empty queue in progress worker for device %s", device)
|
||||
pass
|
||||
except ValueError:
|
||||
except ValueError as e:
|
||||
logger.debug("value error in progress worker for device %s: %s", device, e)
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("error in progress worker")
|
||||
logger.exception("error in progress worker for device %s", device)
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
from threading import Timer
|
||||
|
||||
|
||||
class Interval(Timer):
|
||||
"""
|
||||
From https://stackoverflow.com/a/48741004
|
||||
"""
|
||||
|
||||
def run(self):
|
||||
while not self.finished.wait(self.interval):
|
||||
self.function(*self.args, **self.kwargs)
|
Loading…
Reference in New Issue