feat(api): check device worker pool and recycle on a regular interval (#284)
This commit is contained in:
parent
aeb71ad50a
commit
e552a5560f
|
@ -27,7 +27,7 @@ def buffer_external_data_tensors(
|
||||||
for tensor in model.graph.initializer:
|
for tensor in model.graph.initializer:
|
||||||
name = tensor.name
|
name = tensor.name
|
||||||
|
|
||||||
logger.debug("externalizing tensor: %s", name)
|
logger.trace("externalizing tensor: %s", name)
|
||||||
if tensor.HasField("raw_data"):
|
if tensor.HasField("raw_data"):
|
||||||
npt = numpy_helper.to_array(tensor)
|
npt = numpy_helper.to_array(tensor)
|
||||||
orv = OrtValue.ortvalue_from_numpy(npt)
|
orv = OrtValue.ortvalue_from_numpy(npt)
|
||||||
|
@ -74,7 +74,7 @@ def blend_loras(
|
||||||
|
|
||||||
blended: Dict[str, np.ndarray] = {}
|
blended: Dict[str, np.ndarray] = {}
|
||||||
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
|
||||||
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
|
||||||
if lora_model is None:
|
if lora_model is None:
|
||||||
logger.warning("unable to load tensor for LoRA")
|
logger.warning("unable to load tensor for LoRA")
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -10,6 +10,7 @@ from ..params import DeviceParams
|
||||||
from ..server import ServerContext
|
from ..server import ServerContext
|
||||||
from .command import JobCommand, ProgressCommand
|
from .command import JobCommand, ProgressCommand
|
||||||
from .context import WorkerContext
|
from .context import WorkerContext
|
||||||
|
from .utils import Interval
|
||||||
from .worker import worker_main
|
from .worker import worker_main
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -18,17 +19,24 @@ logger = getLogger(__name__)
|
||||||
class DevicePoolExecutor:
|
class DevicePoolExecutor:
|
||||||
server: ServerContext
|
server: ServerContext
|
||||||
devices: List[DeviceParams]
|
devices: List[DeviceParams]
|
||||||
|
|
||||||
|
join_timeout: float
|
||||||
max_jobs_per_worker: int
|
max_jobs_per_worker: int
|
||||||
max_pending_per_worker: int
|
max_pending_per_worker: int
|
||||||
join_timeout: float
|
progress_interval: float
|
||||||
|
recycle_interval: float
|
||||||
|
|
||||||
leaking: List[Tuple[str, Process]]
|
leaking: List[Tuple[str, Process]]
|
||||||
context: Dict[str, WorkerContext] # Device -> Context
|
context: Dict[str, WorkerContext] # Device -> Context
|
||||||
current: Dict[str, "Value[int]"] # Device -> pid
|
current: Dict[str, "Value[int]"] # Device -> pid
|
||||||
pending: Dict[str, "Queue[JobCommand]"]
|
pending: Dict[str, "Queue[JobCommand]"]
|
||||||
threads: Dict[str, Thread]
|
progress: Dict[str, "Queue[ProgressCommand]"]
|
||||||
workers: Dict[str, Process]
|
workers: Dict[str, Process]
|
||||||
|
|
||||||
|
health_worker: Interval
|
||||||
|
logger_worker: Thread
|
||||||
|
progress_worker: Interval
|
||||||
|
|
||||||
cancelled_jobs: List[str]
|
cancelled_jobs: List[str]
|
||||||
finished_jobs: List[ProgressCommand]
|
finished_jobs: List[ProgressCommand]
|
||||||
pending_jobs: List[JobCommand]
|
pending_jobs: List[JobCommand]
|
||||||
|
@ -36,8 +44,7 @@ class DevicePoolExecutor:
|
||||||
total_jobs: Dict[str, int] # Device -> job count
|
total_jobs: Dict[str, int] # Device -> job count
|
||||||
|
|
||||||
logs: "Queue[str]"
|
logs: "Queue[str]"
|
||||||
progress: "Queue[ProgressCommand]"
|
rlock: Lock
|
||||||
recycle: Lock
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -45,18 +52,23 @@ class DevicePoolExecutor:
|
||||||
devices: List[DeviceParams],
|
devices: List[DeviceParams],
|
||||||
max_pending_per_worker: int = 100,
|
max_pending_per_worker: int = 100,
|
||||||
join_timeout: float = 1.0,
|
join_timeout: float = 1.0,
|
||||||
|
recycle_interval: float = 10,
|
||||||
|
progress_interval: float = 1.0,
|
||||||
):
|
):
|
||||||
self.server = server
|
self.server = server
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
|
|
||||||
|
self.join_timeout = join_timeout
|
||||||
self.max_jobs_per_worker = server.job_limit
|
self.max_jobs_per_worker = server.job_limit
|
||||||
self.max_pending_per_worker = max_pending_per_worker
|
self.max_pending_per_worker = max_pending_per_worker
|
||||||
self.join_timeout = join_timeout
|
self.progress_interval = progress_interval
|
||||||
|
self.recycle_interval = recycle_interval
|
||||||
|
|
||||||
self.leaking = []
|
self.leaking = []
|
||||||
self.context = {}
|
self.context = {}
|
||||||
self.current = {}
|
self.current = {}
|
||||||
self.pending = {}
|
self.pending = {}
|
||||||
self.threads = {}
|
self.progress = {}
|
||||||
self.workers = {}
|
self.workers = {}
|
||||||
|
|
||||||
self.cancelled_jobs = []
|
self.cancelled_jobs = []
|
||||||
|
@ -66,10 +78,10 @@ class DevicePoolExecutor:
|
||||||
self.total_jobs = {}
|
self.total_jobs = {}
|
||||||
|
|
||||||
self.logs = Queue(self.max_pending_per_worker)
|
self.logs = Queue(self.max_pending_per_worker)
|
||||||
self.progress = Queue(self.max_pending_per_worker)
|
self.rlock = Lock()
|
||||||
self.recycle = Lock()
|
|
||||||
|
|
||||||
# TODO: these should be part of a start method
|
# TODO: these should be part of a start method
|
||||||
|
self.create_health_worker()
|
||||||
self.create_logger_worker()
|
self.create_logger_worker()
|
||||||
self.create_progress_worker()
|
self.create_progress_worker()
|
||||||
|
|
||||||
|
@ -79,14 +91,9 @@ class DevicePoolExecutor:
|
||||||
def create_device_worker(self, device: DeviceParams) -> None:
|
def create_device_worker(self, device: DeviceParams) -> None:
|
||||||
name = device.device
|
name = device.device
|
||||||
|
|
||||||
# reuse the queue if possible, to keep queued jobs
|
# always recreate queues
|
||||||
if name in self.pending:
|
self.progress[name] = Queue(self.max_pending_per_worker)
|
||||||
logger.debug("using existing pending job queue")
|
self.pending[name] = Queue(self.max_pending_per_worker)
|
||||||
pending = self.pending[name]
|
|
||||||
else:
|
|
||||||
logger.debug("creating new pending job queue")
|
|
||||||
pending = Queue(self.max_pending_per_worker)
|
|
||||||
self.pending[name] = pending
|
|
||||||
|
|
||||||
if name in self.current:
|
if name in self.current:
|
||||||
logger.debug("using existing current worker value")
|
logger.debug("using existing current worker value")
|
||||||
|
@ -100,9 +107,9 @@ class DevicePoolExecutor:
|
||||||
name,
|
name,
|
||||||
device,
|
device,
|
||||||
cancel=Value("B", False),
|
cancel=Value("B", False),
|
||||||
progress=self.progress,
|
progress=self.progress[name],
|
||||||
logs=self.logs,
|
logs=self.logs,
|
||||||
pending=pending,
|
pending=self.pending[name],
|
||||||
active_pid=current,
|
active_pid=current,
|
||||||
)
|
)
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
|
@ -117,8 +124,15 @@ class DevicePoolExecutor:
|
||||||
self.workers[name] = worker
|
self.workers[name] = worker
|
||||||
current.value = worker.pid
|
current.value = worker.pid
|
||||||
|
|
||||||
|
def create_health_worker(self) -> None:
|
||||||
|
self.health_worker = Interval(self.recycle_interval, health_main, args=(self,))
|
||||||
|
self.health_worker.daemon = True
|
||||||
|
self.health_worker.name = "onnx-web health"
|
||||||
|
logger.debug("starting health worker")
|
||||||
|
self.health_worker.start()
|
||||||
|
|
||||||
def create_logger_worker(self) -> None:
|
def create_logger_worker(self) -> None:
|
||||||
logger_thread = Thread(
|
self.logger_worker = Thread(
|
||||||
name="onnx-web logger",
|
name="onnx-web logger",
|
||||||
target=logger_main,
|
target=logger_main,
|
||||||
args=(
|
args=(
|
||||||
|
@ -127,25 +141,18 @@ class DevicePoolExecutor:
|
||||||
),
|
),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
self.threads["logger"] = logger_thread
|
|
||||||
|
|
||||||
logger.debug("starting logger worker")
|
logger.debug("starting logger worker")
|
||||||
logger_thread.start()
|
self.logger_worker.start()
|
||||||
|
|
||||||
def create_progress_worker(self) -> None:
|
def create_progress_worker(self) -> None:
|
||||||
progress_thread = Thread(
|
self.progress_worker = Interval(
|
||||||
name="onnx-web progress",
|
self.progress_interval, progress_main, args=(self, self.progress)
|
||||||
target=progress_main,
|
|
||||||
args=(
|
|
||||||
self,
|
|
||||||
self.progress,
|
|
||||||
),
|
|
||||||
daemon=True,
|
|
||||||
)
|
)
|
||||||
self.threads["progress"] = progress_thread
|
self.progress_worker.daemon = True
|
||||||
|
self.progress_worker.name = "onnx-web progress"
|
||||||
logger.debug("starting progress worker")
|
logger.debug("starting progress worker")
|
||||||
progress_thread.start()
|
self.progress_worker.start()
|
||||||
|
|
||||||
def get_job_context(self, key: str) -> WorkerContext:
|
def get_job_context(self, key: str) -> WorkerContext:
|
||||||
device, _progress = self.running_jobs[key]
|
device, _progress = self.running_jobs[key]
|
||||||
|
@ -225,10 +232,11 @@ class DevicePoolExecutor:
|
||||||
def join(self):
|
def join(self):
|
||||||
logger.info("stopping worker pool")
|
logger.info("stopping worker pool")
|
||||||
|
|
||||||
with self.recycle:
|
with self.rlock:
|
||||||
logger.debug("closing queues")
|
logger.debug("closing queues")
|
||||||
self.logs.close()
|
self.logs.close()
|
||||||
self.progress.close()
|
for queue in self.progress.values():
|
||||||
|
queue.close()
|
||||||
for queue in self.pending.values():
|
for queue in self.pending.values():
|
||||||
queue.close()
|
queue.close()
|
||||||
|
|
||||||
|
@ -250,9 +258,14 @@ class DevicePoolExecutor:
|
||||||
else:
|
else:
|
||||||
logger.debug("worker for device %s has died", device)
|
logger.debug("worker for device %s has died", device)
|
||||||
|
|
||||||
for name, thread in self.threads.items():
|
logger.debug("stopping progress worker")
|
||||||
logger.debug("stopping worker %s for thread %s", thread.ident, name)
|
self.progress_worker.join(self.join_timeout)
|
||||||
thread.join(self.join_timeout)
|
|
||||||
|
logger.debug("stopping logger worker")
|
||||||
|
self.logger_worker.join(self.join_timeout)
|
||||||
|
|
||||||
|
logger.debug("stopping health worker")
|
||||||
|
self.health_worker.join(self.join_timeout)
|
||||||
|
|
||||||
logger.debug("worker pool stopped")
|
logger.debug("worker pool stopped")
|
||||||
|
|
||||||
|
@ -276,7 +289,7 @@ class DevicePoolExecutor:
|
||||||
def recycle(self):
|
def recycle(self):
|
||||||
logger.debug("recycling worker pool")
|
logger.debug("recycling worker pool")
|
||||||
|
|
||||||
with self.recycle:
|
with self.rlock:
|
||||||
self.join_leaking()
|
self.join_leaking()
|
||||||
|
|
||||||
needs_restart = []
|
needs_restart = []
|
||||||
|
@ -318,22 +331,14 @@ class DevicePoolExecutor:
|
||||||
self.create_device_worker(device)
|
self.create_device_worker(device)
|
||||||
self.total_jobs[device.device] = 0
|
self.total_jobs[device.device] = 0
|
||||||
|
|
||||||
if self.threads["logger"].is_alive():
|
if self.logger_worker.is_alive():
|
||||||
logger.debug("logger worker is running")
|
logger.debug("logger worker is running")
|
||||||
if self.logs.full():
|
|
||||||
logger.warning("logger queue is full, restarting worker")
|
|
||||||
self.threads["logger"].join(self.join_timeout)
|
|
||||||
self.create_logger_worker()
|
|
||||||
else:
|
else:
|
||||||
logger.warning("restarting logger worker")
|
logger.warning("restarting logger worker")
|
||||||
self.create_logger_worker()
|
self.create_logger_worker()
|
||||||
|
|
||||||
if self.threads["progress"].is_alive():
|
if self.progress_worker.is_alive():
|
||||||
logger.debug("progress worker is running")
|
logger.debug("progress worker is running")
|
||||||
if self.progress.full():
|
|
||||||
logger.warning("progress queue is full, restarting worker")
|
|
||||||
self.threads["progress"].join(self.join_timeout)
|
|
||||||
self.create_progress_worker()
|
|
||||||
else:
|
else:
|
||||||
logger.warning("restarting progress worker")
|
logger.warning("restarting progress worker")
|
||||||
self.create_progress_worker()
|
self.create_progress_worker()
|
||||||
|
@ -366,7 +371,7 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
# recycle before attempting to run
|
# recycle before attempting to run
|
||||||
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
||||||
self.recycle()
|
self.rlock()
|
||||||
|
|
||||||
# build and queue job
|
# build and queue job
|
||||||
job = JobCommand(key, device, fn, args, kwargs)
|
job = JobCommand(key, device, fn, args, kwargs)
|
||||||
|
@ -443,15 +448,28 @@ class DevicePoolExecutor:
|
||||||
self.context[progress.device].set_cancel()
|
self.context[progress.device].set_cancel()
|
||||||
|
|
||||||
|
|
||||||
def logger_main(pool: DevicePoolExecutor, logs: Queue):
|
def health_main(pool: DevicePoolExecutor):
|
||||||
|
logger.trace("checking in from health worker thread")
|
||||||
|
pool.recycle()
|
||||||
|
|
||||||
|
if pool.logs.full():
|
||||||
|
logger.warning("logger queue is full, restarting worker")
|
||||||
|
pool.logger_worker.join(pool.join_timeout)
|
||||||
|
pool.create_logger_worker()
|
||||||
|
|
||||||
|
if any([queue.full() for queue in pool.progress.values()]):
|
||||||
|
logger.warning("progress queue is full, restarting worker")
|
||||||
|
pool.progress_worker.join(pool.join_timeout)
|
||||||
|
pool.create_progress_worker()
|
||||||
|
|
||||||
|
|
||||||
|
def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"):
|
||||||
logger.trace("checking in from logger worker thread")
|
logger.trace("checking in from logger worker thread")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
job = logs.get(timeout=(pool.join_timeout / 2))
|
msg = logs.get(pool.join_timeout / 2)
|
||||||
with open("worker.log", "w") as f:
|
logger.debug("received logs from worker: %s", msg)
|
||||||
logger.info("got log: %s", job)
|
|
||||||
f.write(str(job) + "\n\n")
|
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -460,15 +478,21 @@ def logger_main(pool: DevicePoolExecutor, logs: Queue):
|
||||||
logger.exception("error in log worker")
|
logger.exception("error in log worker")
|
||||||
|
|
||||||
|
|
||||||
def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"):
|
def progress_main(
|
||||||
|
pool: DevicePoolExecutor, queues: Dict[str, "Queue[ProgressCommand]"]
|
||||||
|
):
|
||||||
logger.trace("checking in from progress worker thread")
|
logger.trace("checking in from progress worker thread")
|
||||||
while True:
|
for device, queue in queues.items():
|
||||||
try:
|
try:
|
||||||
progress = queue.get(timeout=(pool.join_timeout / 2))
|
progress = queue.get_nowait()
|
||||||
pool.update_job(progress)
|
while progress is not None:
|
||||||
|
pool.update_job(progress)
|
||||||
|
progress = queue.get_nowait()
|
||||||
except Empty:
|
except Empty:
|
||||||
|
logger.trace("empty queue in progress worker for device %s", device)
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError as e:
|
||||||
|
logger.debug("value error in progress worker for device %s: %s", device, e)
|
||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error in progress worker")
|
logger.exception("error in progress worker for device %s", device)
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
from threading import Timer
|
||||||
|
|
||||||
|
|
||||||
|
class Interval(Timer):
|
||||||
|
"""
|
||||||
|
From https://stackoverflow.com/a/48741004
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
while not self.finished.wait(self.interval):
|
||||||
|
self.function(*self.args, **self.kwargs)
|
Loading…
Reference in New Issue