fix(api): track completed jobs for each device worker (#170)
This commit is contained in:
parent
4b77a00ca7
commit
1f3a5f6f3c
|
@@ -37,7 +37,9 @@ def pipeline_from_request(
|
|||
lpw = get_not_empty(request.args, "lpw", "false") == "true"
|
||||
model = get_not_empty(request.args, "model", get_config_value("model"))
|
||||
model_path = get_model_path(context, model)
|
||||
scheduler = get_from_list(request.args, "scheduler", pipeline_schedulers.keys())
|
||||
scheduler = get_from_list(
|
||||
request.args, "scheduler", list(pipeline_schedulers.keys())
|
||||
)
|
||||
|
||||
if scheduler is None:
|
||||
scheduler = get_config_value("scheduler")
|
||||
|
|
|
@@ -29,7 +29,7 @@ class DevicePoolExecutor:
|
|||
active_jobs: Dict[str, Tuple[str, int]] # should be Dict[Device, JobStatus]
|
||||
cancelled_jobs: List[str]
|
||||
finished_jobs: List[Tuple[str, int, bool]] # should be List[JobStatus]
|
||||
total_jobs: int
|
||||
total_jobs: Dict[str, int] # Device -> job count
|
||||
|
||||
logs: "Queue"
|
||||
progress: "Queue[Tuple[str, str, int]]"
|
||||
|
@@ -57,12 +57,13 @@ class DevicePoolExecutor:
|
|||
self.active_jobs = {}
|
||||
self.cancelled_jobs = []
|
||||
self.finished_jobs = []
|
||||
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
|
||||
self.total_jobs = {}
|
||||
|
||||
self.logs = Queue(self.max_pending_per_worker)
|
||||
self.progress = Queue(self.max_pending_per_worker)
|
||||
self.finished = Queue(self.max_pending_per_worker)
|
||||
|
||||
# TODO: these should be part of a start method
|
||||
self.create_logger_worker()
|
||||
self.create_progress_worker()
|
||||
self.create_finished_worker()
|
||||
|
@@ -70,9 +71,6 @@ class DevicePoolExecutor:
|
|||
for device in devices:
|
||||
self.create_device_worker(device)
|
||||
|
||||
logger.debug("testing log worker")
|
||||
self.logs.put("testing")
|
||||
|
||||
def create_device_worker(self, device: DeviceParams) -> None:
|
||||
name = device.device
|
||||
|
||||
|
@@ -268,7 +266,10 @@ class DevicePoolExecutor:
|
|||
if worker.is_alive():
|
||||
logger.debug("stopping worker for device %s", device)
|
||||
worker.join(self.join_timeout)
|
||||
# worker.terminate()
|
||||
if worker.is_alive():
|
||||
logger.warning(
|
||||
"worker for device %s could not be stopped in time", device
|
||||
)
|
||||
else:
|
||||
logger.debug("worker for device %s has died", device)
|
||||
|
||||
|
@@ -276,24 +277,43 @@ class DevicePoolExecutor:
|
|||
logger.debug("stopping worker thread: %s", name)
|
||||
thread.join(self.join_timeout)
|
||||
|
||||
logger.debug("worker pool fully joined")
|
||||
logger.debug("worker pool stopped")
|
||||
|
||||
def recycle(self):
|
||||
for name, proc in self.workers.items():
|
||||
if proc.is_alive():
|
||||
logger.debug("shutting down worker for device %s", name)
|
||||
logger.debug("recycling worker pool")
|
||||
needs_restart = []
|
||||
|
||||
for device, proc in self.workers.items():
|
||||
jobs = self.total_jobs.get(device, 0)
|
||||
if not proc.is_alive():
|
||||
logger.warning("worker for device %s has died", device)
|
||||
needs_restart.append(device)
|
||||
elif jobs > self.max_jobs_per_worker:
|
||||
logger.info(
|
||||
"shutting down worker for device %s after %s jobs", device, jobs
|
||||
)
|
||||
proc.join(self.join_timeout)
|
||||
# proc.terminate()
|
||||
if proc.is_alive():
|
||||
logger.warning(
|
||||
"worker for device %s could not be recycled in time", device
|
||||
)
|
||||
|
||||
self.workers[device] = None
|
||||
del proc
|
||||
needs_restart.append(device)
|
||||
else:
|
||||
logger.warning("worker for device %s has died", name)
|
||||
logger.debug(
|
||||
"worker for device %s does not need to be recycled", device
|
||||
)
|
||||
|
||||
self.workers[name] = None
|
||||
del proc
|
||||
|
||||
logger.info("starting new workers")
|
||||
logger.debug("starting new workers")
|
||||
|
||||
for device in self.devices:
|
||||
self.create_device_worker(device)
|
||||
if device.device in needs_restart:
|
||||
self.create_device_worker(device)
|
||||
self.total_jobs[device.device] = 0
|
||||
|
||||
logger.debug("worker pool recycled")
|
||||
|
||||
def submit(
|
||||
self,
|
||||
|
@@ -304,12 +324,6 @@ class DevicePoolExecutor:
|
|||
needs_device: Optional[DeviceParams] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.total_jobs += 1
|
||||
logger.debug("pool job count: %s", self.total_jobs)
|
||||
if self.total_jobs > self.max_jobs_per_worker:
|
||||
self.recycle()
|
||||
self.total_jobs = 0
|
||||
|
||||
device_idx = self.get_next_device(needs_device=needs_device)
|
||||
logger.info(
|
||||
"assigning job %s to device %s: %s",
|
||||
|
@@ -319,6 +333,15 @@ class DevicePoolExecutor:
|
|||
)
|
||||
|
||||
device = self.devices[device_idx].device
|
||||
|
||||
if device in self.total_jobs:
|
||||
self.total_jobs[device] += 1
|
||||
else:
|
||||
self.total_jobs[device] = 1
|
||||
|
||||
logger.debug("device job count: %s", self.total_jobs[device])
|
||||
self.recycle()
|
||||
|
||||
self.pending[device].put((key, fn, args, kwargs), block=False)
|
||||
|
||||
def status(self) -> List[Tuple[str, int, bool, bool]]:
|
||||
|
|
Loading…
Reference in New Issue