fix(api): track completed jobs for each device worker (#170)
This commit is contained in:
parent
4b77a00ca7
commit
1f3a5f6f3c
|
@ -37,7 +37,9 @@ def pipeline_from_request(
|
||||||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||||
model_path = get_model_path(context, model)
|
model_path = get_model_path(context, model)
|
||||||
scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys())
|
scheduler = get_from_list(
|
||||||
|
request.args, "scheduler", list(pipeline_schedulers.keys())
|
||||||
|
)
|
||||||
|
|
||||||
if scheduler is None:
|
if scheduler is None:
|
||||||
scheduler = get_config_value("scheduler")
|
scheduler = get_config_value("scheduler")
|
||||||
|
|
|
@ -29,7 +29,7 @@ class DevicePoolExecutor:
|
||||||
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
|
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
|
||||||
cancelled_jobs: List[str]
|
cancelled_jobs: List[str]
|
||||||
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
|
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
|
||||||
total_jobs: int
|
total_jobs: Dict[str, int] # Device -> job count
|
||||||
|
|
||||||
logs: "Queue"
|
logs: "Queue"
|
||||||
progress: "Queue[Tuple[str, str, int]]"
|
progress: "Queue[Tuple[str, str, int]]"
|
||||||
|
@ -57,12 +57,13 @@ class DevicePoolExecutor:
|
||||||
self.active_jobs = {}
|
self.active_jobs = {}
|
||||||
self.cancelled_jobs = []
|
self.cancelled_jobs = []
|
||||||
self.finished_jobs = []
|
self.finished_jobs = []
|
||||||
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
|
self.total_jobs = {}
|
||||||
|
|
||||||
self.logs = Queue(self.max_pending_per_worker)
|
self.logs = Queue(self.max_pending_per_worker)
|
||||||
self.progress = Queue(self.max_pending_per_worker)
|
self.progress = Queue(self.max_pending_per_worker)
|
||||||
self.finished = Queue(self.max_pending_per_worker)
|
self.finished = Queue(self.max_pending_per_worker)
|
||||||
|
|
||||||
|
# TODO: these should be part of a start method
|
||||||
self.create_logger_worker()
|
self.create_logger_worker()
|
||||||
self.create_progress_worker()
|
self.create_progress_worker()
|
||||||
self.create_finished_worker()
|
self.create_finished_worker()
|
||||||
|
@ -70,9 +71,6 @@ class DevicePoolExecutor:
|
||||||
for device in devices:
|
for device in devices:
|
||||||
self.create_device_worker(device)
|
self.create_device_worker(device)
|
||||||
|
|
||||||
logger.debug("testing log worker")
|
|
||||||
self.logs.put("testing")
|
|
||||||
|
|
||||||
def create_device_worker(self, device: DeviceParams) -> None:
|
def create_device_worker(self, device: DeviceParams) -> None:
|
||||||
name = device.device
|
name = device.device
|
||||||
|
|
||||||
|
@ -268,7 +266,10 @@ class DevicePoolExecutor:
|
||||||
if worker.is_alive():
|
if worker.is_alive():
|
||||||
logger.debug("stopping worker for device %s", device)
|
logger.debug("stopping worker for device %s", device)
|
||||||
worker.join(self.join_timeout)
|
worker.join(self.join_timeout)
|
||||||
# worker.terminate()
|
if worker.is_alive():
|
||||||
|
logger.warning(
|
||||||
|
"worker for device %s could not be stopped in time", device
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.debug("worker for device %s has died", device)
|
logger.debug("worker for device %s has died", device)
|
||||||
|
|
||||||
|
@ -276,24 +277,43 @@ class DevicePoolExecutor:
|
||||||
logger.debug("stopping worker thread: %s", name)
|
logger.debug("stopping worker thread: %s", name)
|
||||||
thread.join(self.join_timeout)
|
thread.join(self.join_timeout)
|
||||||
|
|
||||||
logger.debug("worker pool fully joined")
|
logger.debug("worker pool stopped")
|
||||||
|
|
||||||
def recycle(self):
|
def recycle(self):
|
||||||
for name, proc in self.workers.items():
|
logger.debug("recycling worker pool")
|
||||||
if proc.is_alive():
|
needs_restart = []
|
||||||
logger.debug("shutting down worker for device %s", name)
|
|
||||||
|
for device, proc in self.workers.items():
|
||||||
|
jobs = self.total_jobs.get(device, 0)
|
||||||
|
if not proc.is_alive():
|
||||||
|
logger.warning("worker for device %s has died", device)
|
||||||
|
needs_restart.append(device)
|
||||||
|
elif jobs > self.max_jobs_per_worker:
|
||||||
|
logger.info(
|
||||||
|
"shutting down worker for device %s after %s jobs", device, jobs
|
||||||
|
)
|
||||||
proc.join(self.join_timeout)
|
proc.join(self.join_timeout)
|
||||||
# proc.terminate()
|
if proc.is_alive():
|
||||||
|
logger.warning(
|
||||||
|
"worker for device %s could not be recycled in time", device
|
||||||
|
)
|
||||||
|
|
||||||
|
self.workers[device] = None
|
||||||
|
del proc
|
||||||
|
needs_restart.append(device)
|
||||||
else:
|
else:
|
||||||
logger.warning("worker for device %s has died", name)
|
logger.debug(
|
||||||
|
"worker for device %s does not need to be recycled", device
|
||||||
|
)
|
||||||
|
|
||||||
self.workers[name] = None
|
logger.debug("starting new workers")
|
||||||
del proc
|
|
||||||
|
|
||||||
logger.info("starting new workers")
|
|
||||||
|
|
||||||
for device in self.devices:
|
for device in self.devices:
|
||||||
self.create_device_worker(device)
|
if device.device in needs_restart:
|
||||||
|
self.create_device_worker(device)
|
||||||
|
self.total_jobs[device.device] = 0
|
||||||
|
|
||||||
|
logger.debug("worker pool recycled")
|
||||||
|
|
||||||
def submit(
|
def submit(
|
||||||
self,
|
self,
|
||||||
|
@ -304,12 +324,6 @@ class DevicePoolExecutor:
|
||||||
needs_device: Optional[DeviceParams] = None,
|
needs_device: Optional[DeviceParams] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.total_jobs += 1
|
|
||||||
logger.debug("pool job count: %s", self.total_jobs)
|
|
||||||
if self.total_jobs > self.max_jobs_per_worker:
|
|
||||||
self.recycle()
|
|
||||||
self.total_jobs = 0
|
|
||||||
|
|
||||||
device_idx = self.get_next_device(needs_device=needs_device)
|
device_idx = self.get_next_device(needs_device=needs_device)
|
||||||
logger.info(
|
logger.info(
|
||||||
"assigning job %s to device %s: %s",
|
"assigning job %s to device %s: %s",
|
||||||
|
@ -319,6 +333,15 @@ class DevicePoolExecutor:
|
||||||
)
|
)
|
||||||
|
|
||||||
device = self.devices[device_idx].device
|
device = self.devices[device_idx].device
|
||||||
|
|
||||||
|
if device in self.total_jobs:
|
||||||
|
self.total_jobs[device] += 1
|
||||||
|
else:
|
||||||
|
self.total_jobs[device] = 1
|
||||||
|
|
||||||
|
logger.debug("device job count: %s", self.total_jobs[device])
|
||||||
|
self.recycle()
|
||||||
|
|
||||||
self.pending[device].put((key, fn, args, kwargs), block=False)
|
self.pending[device].put((key, fn, args, kwargs), block=False)
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, bool]]:
|
def status(self) -> List[Tuple[str, int, bool, bool]]:
|
||||||
|
|
Loading…
Reference in New Issue