1
0
Fork 0

feat(api): check device worker pool and recycle on a regular interval (#284)

This commit is contained in:
Sean Sube 2023-03-26 11:09:13 -05:00
parent aeb71ad50a
commit e552a5560f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 97 additions and 62 deletions

View File

@ -27,7 +27,7 @@ def buffer_external_data_tensors(
for tensor in model.graph.initializer: for tensor in model.graph.initializer:
name = tensor.name name = tensor.name
logger.debug("externalizing tensor: %s", name) logger.trace("externalizing tensor: %s", name)
if tensor.HasField("raw_data"): if tensor.HasField("raw_data"):
npt = numpy_helper.to_array(tensor) npt = numpy_helper.to_array(tensor)
orv = OrtValue.ortvalue_from_numpy(npt) orv = OrtValue.ortvalue_from_numpy(npt)
@ -74,7 +74,7 @@ def blend_loras(
blended: Dict[str, np.ndarray] = {} blended: Dict[str, np.ndarray] = {}
for (lora_name, lora_weight), lora_model in zip(loras, lora_models): for (lora_name, lora_weight), lora_model in zip(loras, lora_models):
logger.info("blending LoRA from %s with weight of %s", lora_name, lora_weight) logger.debug("blending LoRA from %s with weight of %s", lora_name, lora_weight)
if lora_model is None: if lora_model is None:
logger.warning("unable to load tensor for LoRA") logger.warning("unable to load tensor for LoRA")
continue continue

View File

@ -10,6 +10,7 @@ from ..params import DeviceParams
from ..server import ServerContext from ..server import ServerContext
from .command import JobCommand, ProgressCommand from .command import JobCommand, ProgressCommand
from .context import WorkerContext from .context import WorkerContext
from .utils import Interval
from .worker import worker_main from .worker import worker_main
logger = getLogger(__name__) logger = getLogger(__name__)
@ -18,17 +19,24 @@ logger = getLogger(__name__)
class DevicePoolExecutor: class DevicePoolExecutor:
server: ServerContext server: ServerContext
devices: List[DeviceParams] devices: List[DeviceParams]
join_timeout: float
max_jobs_per_worker: int max_jobs_per_worker: int
max_pending_per_worker: int max_pending_per_worker: int
join_timeout: float progress_interval: float
recycle_interval: float
leaking: List[Tuple[str, Process]] leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"] # Device -> pid current: Dict[str, "Value[int]"] # Device -> pid
pending: Dict[str, "Queue[JobCommand]"] pending: Dict[str, "Queue[JobCommand]"]
threads: Dict[str, Thread] progress: Dict[str, "Queue[ProgressCommand]"]
workers: Dict[str, Process] workers: Dict[str, Process]
health_worker: Interval
logger_worker: Thread
progress_worker: Interval
cancelled_jobs: List[str] cancelled_jobs: List[str]
finished_jobs: List[ProgressCommand] finished_jobs: List[ProgressCommand]
pending_jobs: List[JobCommand] pending_jobs: List[JobCommand]
@ -36,8 +44,7 @@ class DevicePoolExecutor:
total_jobs: Dict[str, int] # Device -> job count total_jobs: Dict[str, int] # Device -> job count
logs: "Queue[str]" logs: "Queue[str]"
progress: "Queue[ProgressCommand]" rlock: Lock
recycle: Lock
def __init__( def __init__(
self, self,
@ -45,18 +52,23 @@ class DevicePoolExecutor:
devices: List[DeviceParams], devices: List[DeviceParams],
max_pending_per_worker: int = 100, max_pending_per_worker: int = 100,
join_timeout: float = 1.0, join_timeout: float = 1.0,
recycle_interval: float = 10,
progress_interval: float = 1.0,
): ):
self.server = server self.server = server
self.devices = devices self.devices = devices
self.join_timeout = join_timeout
self.max_jobs_per_worker = server.job_limit self.max_jobs_per_worker = server.job_limit
self.max_pending_per_worker = max_pending_per_worker self.max_pending_per_worker = max_pending_per_worker
self.join_timeout = join_timeout self.progress_interval = progress_interval
self.recycle_interval = recycle_interval
self.leaking = [] self.leaking = []
self.context = {} self.context = {}
self.current = {} self.current = {}
self.pending = {} self.pending = {}
self.threads = {} self.progress = {}
self.workers = {} self.workers = {}
self.cancelled_jobs = [] self.cancelled_jobs = []
@ -66,10 +78,10 @@ class DevicePoolExecutor:
self.total_jobs = {} self.total_jobs = {}
self.logs = Queue(self.max_pending_per_worker) self.logs = Queue(self.max_pending_per_worker)
self.progress = Queue(self.max_pending_per_worker) self.rlock = Lock()
self.recycle = Lock()
# TODO: these should be part of a start method # TODO: these should be part of a start method
self.create_health_worker()
self.create_logger_worker() self.create_logger_worker()
self.create_progress_worker() self.create_progress_worker()
@ -79,14 +91,9 @@ class DevicePoolExecutor:
def create_device_worker(self, device: DeviceParams) -> None: def create_device_worker(self, device: DeviceParams) -> None:
name = device.device name = device.device
# reuse the queue if possible, to keep queued jobs # always recreate queues
if name in self.pending: self.progress[name] = Queue(self.max_pending_per_worker)
logger.debug("using existing pending job queue") self.pending[name] = Queue(self.max_pending_per_worker)
pending = self.pending[name]
else:
logger.debug("creating new pending job queue")
pending = Queue(self.max_pending_per_worker)
self.pending[name] = pending
if name in self.current: if name in self.current:
logger.debug("using existing current worker value") logger.debug("using existing current worker value")
@ -100,9 +107,9 @@ class DevicePoolExecutor:
name, name,
device, device,
cancel=Value("B", False), cancel=Value("B", False),
progress=self.progress, progress=self.progress[name],
logs=self.logs, logs=self.logs,
pending=pending, pending=self.pending[name],
active_pid=current, active_pid=current,
) )
self.context[name] = context self.context[name] = context
@ -117,8 +124,15 @@ class DevicePoolExecutor:
self.workers[name] = worker self.workers[name] = worker
current.value = worker.pid current.value = worker.pid
def create_health_worker(self) -> None:
self.health_worker = Interval(self.recycle_interval, health_main, args=(self,))
self.health_worker.daemon = True
self.health_worker.name = "onnx-web health"
logger.debug("starting health worker")
self.health_worker.start()
def create_logger_worker(self) -> None: def create_logger_worker(self) -> None:
logger_thread = Thread( self.logger_worker = Thread(
name="onnx-web logger", name="onnx-web logger",
target=logger_main, target=logger_main,
args=( args=(
@ -127,25 +141,18 @@ class DevicePoolExecutor:
), ),
daemon=True, daemon=True,
) )
self.threads["logger"] = logger_thread
logger.debug("starting logger worker") logger.debug("starting logger worker")
logger_thread.start() self.logger_worker.start()
def create_progress_worker(self) -> None: def create_progress_worker(self) -> None:
progress_thread = Thread( self.progress_worker = Interval(
name="onnx-web progress", self.progress_interval, progress_main, args=(self, self.progress)
target=progress_main,
args=(
self,
self.progress,
),
daemon=True,
) )
self.threads["progress"] = progress_thread self.progress_worker.daemon = True
self.progress_worker.name = "onnx-web progress"
logger.debug("starting progress worker") logger.debug("starting progress worker")
progress_thread.start() self.progress_worker.start()
def get_job_context(self, key: str) -> WorkerContext: def get_job_context(self, key: str) -> WorkerContext:
device, _progress = self.running_jobs[key] device, _progress = self.running_jobs[key]
@ -225,10 +232,11 @@ class DevicePoolExecutor:
def join(self): def join(self):
logger.info("stopping worker pool") logger.info("stopping worker pool")
with self.recycle: with self.rlock:
logger.debug("closing queues") logger.debug("closing queues")
self.logs.close() self.logs.close()
self.progress.close() for queue in self.progress.values():
queue.close()
for queue in self.pending.values(): for queue in self.pending.values():
queue.close() queue.close()
@ -250,9 +258,14 @@ class DevicePoolExecutor:
else: else:
logger.debug("worker for device %s has died", device) logger.debug("worker for device %s has died", device)
for name, thread in self.threads.items(): logger.debug("stopping progress worker")
logger.debug("stopping worker %s for thread %s", thread.ident, name) self.progress_worker.join(self.join_timeout)
thread.join(self.join_timeout)
logger.debug("stopping logger worker")
self.logger_worker.join(self.join_timeout)
logger.debug("stopping health worker")
self.health_worker.join(self.join_timeout)
logger.debug("worker pool stopped") logger.debug("worker pool stopped")
@ -276,7 +289,7 @@ class DevicePoolExecutor:
def recycle(self): def recycle(self):
logger.debug("recycling worker pool") logger.debug("recycling worker pool")
with self.recycle: with self.rlock:
self.join_leaking() self.join_leaking()
needs_restart = [] needs_restart = []
@ -318,22 +331,14 @@ class DevicePoolExecutor:
self.create_device_worker(device) self.create_device_worker(device)
self.total_jobs[device.device] = 0 self.total_jobs[device.device] = 0
if self.threads["logger"].is_alive(): if self.logger_worker.is_alive():
logger.debug("logger worker is running") logger.debug("logger worker is running")
if self.logs.full():
logger.warning("logger queue is full, restarting worker")
self.threads["logger"].join(self.join_timeout)
self.create_logger_worker()
else: else:
logger.warning("restarting logger worker") logger.warning("restarting logger worker")
self.create_logger_worker() self.create_logger_worker()
if self.threads["progress"].is_alive(): if self.progress_worker.is_alive():
logger.debug("progress worker is running") logger.debug("progress worker is running")
if self.progress.full():
logger.warning("progress queue is full, restarting worker")
self.threads["progress"].join(self.join_timeout)
self.create_progress_worker()
else: else:
logger.warning("restarting progress worker") logger.warning("restarting progress worker")
self.create_progress_worker() self.create_progress_worker()
@ -366,7 +371,7 @@ class DevicePoolExecutor:
# recycle before attempting to run # recycle before attempting to run
logger.debug("job count for device %s: %s", device, self.total_jobs[device]) logger.debug("job count for device %s: %s", device, self.total_jobs[device])
self.recycle() self.rlock()
# build and queue job # build and queue job
job = JobCommand(key, device, fn, args, kwargs) job = JobCommand(key, device, fn, args, kwargs)
@ -443,15 +448,28 @@ class DevicePoolExecutor:
self.context[progress.device].set_cancel() self.context[progress.device].set_cancel()
def logger_main(pool: DevicePoolExecutor, logs: Queue): def health_main(pool: DevicePoolExecutor):
logger.trace("checking in from health worker thread")
pool.recycle()
if pool.logs.full():
logger.warning("logger queue is full, restarting worker")
pool.logger_worker.join(pool.join_timeout)
pool.create_logger_worker()
if any([queue.full() for queue in pool.progress.values()]):
logger.warning("progress queue is full, restarting worker")
pool.progress_worker.join(pool.join_timeout)
pool.create_progress_worker()
def logger_main(pool: DevicePoolExecutor, logs: "Queue[str]"):
logger.trace("checking in from logger worker thread") logger.trace("checking in from logger worker thread")
while True: while True:
try: try:
job = logs.get(timeout=(pool.join_timeout / 2)) msg = logs.get(pool.join_timeout / 2)
with open("worker.log", "w") as f: logger.debug("received logs from worker: %s", msg)
logger.info("got log: %s", job)
f.write(str(job) + "\n\n")
except Empty: except Empty:
pass pass
except ValueError: except ValueError:
@ -460,15 +478,21 @@ def logger_main(pool: DevicePoolExecutor, logs: Queue):
logger.exception("error in log worker") logger.exception("error in log worker")
def progress_main(pool: DevicePoolExecutor, queue: "Queue[ProgressCommand]"): def progress_main(
pool: DevicePoolExecutor, queues: Dict[str, "Queue[ProgressCommand]"]
):
logger.trace("checking in from progress worker thread") logger.trace("checking in from progress worker thread")
while True: for device, queue in queues.items():
try: try:
progress = queue.get(timeout=(pool.join_timeout / 2)) progress = queue.get_nowait()
pool.update_job(progress) while progress is not None:
pool.update_job(progress)
progress = queue.get_nowait()
except Empty: except Empty:
logger.trace("empty queue in progress worker for device %s", device)
pass pass
except ValueError: except ValueError as e:
logger.debug("value error in progress worker for device %s: %s", device, e)
break break
except Exception: except Exception:
logger.exception("error in progress worker") logger.exception("error in progress worker for device %s", device)

View File

@ -0,0 +1,11 @@
from threading import Timer
class Interval(Timer):
    """Repeating variant of threading.Timer.

    Calls *function* every *interval* seconds until cancel() is invoked.
    Based on https://stackoverflow.com/a/48741004
    """

    def run(self):
        # finished.wait() returns True once cancel() has set the event,
        # which ends the loop; a False return means the timeout elapsed,
        # so fire the callback and wait again.
        while True:
            if self.finished.wait(self.interval):
                break
            self.function(*self.args, **self.kwargs)