1
0
Fork 0

fix(api): track completed jobs for each device worker (#170)

This commit is contained in:
Sean Sube 2023-03-01 19:09:18 -06:00
parent 4b77a00ca7
commit 1f3a5f6f3c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 49 additions and 24 deletions

View File

@ -37,7 +37,9 @@ def pipeline_from_request(
lpw = get_not_empty(request.args, "lpw", "false") == "true"
model = get_not_empty(request.args, "model", get_config_value("model"))
model_path = get_model_path(context, model)
scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys())
scheduler = get_from_list(
request.args, "scheduler", list(pipeline_schedulers.keys())
)
if scheduler is None:
scheduler = get_config_value("scheduler")

View File

@ -29,7 +29,7 @@ class DevicePoolExecutor:
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
cancelled_jobs: List[str]
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
total_jobs: int
total_jobs: Dict[str, int] # Device -> job count
logs: "Queue"
progress: "Queue[Tuple[str, str, int]]"
@ -57,12 +57,13 @@ class DevicePoolExecutor:
self.active_jobs = {}
self.cancelled_jobs = []
self.finished_jobs = []
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
self.total_jobs = {}
self.logs = Queue(self.max_pending_per_worker)
self.progress = Queue(self.max_pending_per_worker)
self.finished = Queue(self.max_pending_per_worker)
# TODO: these should be part of a start method
self.create_logger_worker()
self.create_progress_worker()
self.create_finished_worker()
@ -70,9 +71,6 @@ class DevicePoolExecutor:
for device in devices:
self.create_device_worker(device)
logger.debug("testing log worker")
self.logs.put("testing")
def create_device_worker(self, device: DeviceParams) -> None:
name = device.device
@ -268,7 +266,10 @@ class DevicePoolExecutor:
if worker.is_alive():
logger.debug("stopping worker for device %s", device)
worker.join(self.join_timeout)
# worker.terminate()
if worker.is_alive():
logger.warning(
"worker for device %s could not be stopped in time", device
)
else:
logger.debug("worker for device %s has died", device)
@ -276,24 +277,43 @@ class DevicePoolExecutor:
logger.debug("stopping worker thread: %s", name)
thread.join(self.join_timeout)
logger.debug("worker pool fully joined")
logger.debug("worker pool stopped")
def recycle(self):
for name, proc in self.workers.items():
if proc.is_alive():
logger.debug("shutting down worker for device %s", name)
logger.debug("recycling worker pool")
needs_restart = []
for device, proc in self.workers.items():
jobs = self.total_jobs.get(device, 0)
if not proc.is_alive():
logger.warning("worker for device %s has died", device)
needs_restart.append(device)
elif jobs > self.max_jobs_per_worker:
logger.info(
"shutting down worker for device %s after %s jobs", device, jobs
)
proc.join(self.join_timeout)
# proc.terminate()
if proc.is_alive():
logger.warning(
"worker for device %s could not be recycled in time", device
)
self.workers[device] = None
del proc
needs_restart.append(device)
else:
logger.warning("worker for device %s has died", name)
logger.debug(
"worker for device %s does not need to be recycled", device
)
self.workers[name] = None
del proc
logger.info("starting new workers")
logger.debug("starting new workers")
for device in self.devices:
self.create_device_worker(device)
if device.device in needs_restart:
self.create_device_worker(device)
self.total_jobs[device.device] = 0
logger.debug("worker pool recycled")
def submit(
self,
@ -304,12 +324,6 @@ class DevicePoolExecutor:
needs_device: Optional[DeviceParams] = None,
**kwargs,
) -> None:
self.total_jobs += 1
logger.debug("pool job count: %s", self.total_jobs)
if self.total_jobs > self.max_jobs_per_worker:
self.recycle()
self.total_jobs = 0
device_idx = self.get_next_device(needs_device=needs_device)
logger.info(
"assigning job %s to device %s: %s",
@ -319,6 +333,15 @@ class DevicePoolExecutor:
)
device = self.devices[device_idx].device
if device in self.total_jobs:
self.total_jobs[device] += 1
else:
self.total_jobs[device] = 1
logger.debug("device job count: %s", self.total_jobs[device])
self.recycle()
self.pending[device].put((key, fn, args, kwargs), block=False)
def status(self) -> List[Tuple[str, int, bool, bool]]: