1
0
Fork 0

fix(api): remove finished jobs from worker pool (#124)

This commit is contained in:
Sean Sube 2023-02-14 17:23:23 -06:00
parent 38f8aa38ee
commit feb4603171
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 41 additions and 2 deletions

View File

@ -107,15 +107,19 @@ class DevicePoolExecutor:
jobs: List[Job] = None jobs: List[Job] = None
next_device: int = 0 next_device: int = 0
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
recent: List[Tuple[str, int]] = None
def __init__( def __init__(
self, self,
devices: List[DeviceParams], devices: List[DeviceParams],
pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None, pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
recent_limit: int = 10,
): ):
self.devices = devices self.devices = devices
self.jobs = [] self.jobs = []
self.next_device = 0 self.next_device = 0
self.recent = []
self.recent_limit = recent_limit
device_count = len(devices) device_count = len(devices)
if pool is None: if pool is None:
@ -150,10 +154,18 @@ class DevicePoolExecutor:
return False return False
def done(self, key: str) -> Tuple[Optional[bool], int]: def done(self, key: str) -> Tuple[Optional[bool], int]:
for k, progress in self.recent:
if key == k:
return (True, progress)
for job in self.jobs: for job in self.jobs:
if job.key == key: if job.key == key:
done = job.future.done() done = job.future.done()
progress = job.get_progress() progress = job.get_progress()
if done:
self.prune()
return (done, progress) return (done, progress)
logger.warn("checking status for unknown key: %s", key) logger.warn("checking status for unknown key: %s", key)
@ -186,7 +198,21 @@ class DevicePoolExecutor:
return lowest_devices[0] return lowest_devices[0]
def prune(self): def prune(self):
self.jobs[:] = [job for job in self.jobs if job.future.done()] pending_jobs = [job for job in self.jobs if job.future.done()]
logger.debug("pruning %s of %s pending jobs", len(pending_jobs), len(self.jobs))
for job in pending_jobs:
self.recent.append((job.key, job.get_progress()))
try:
self.jobs.remove(job)
except ValueError as e:
logger.warning("error removing pruned job from pending: %s", e)
# self.jobs[:] = [job for job in self.jobs if not job.future.done()]
recent_count = len(self.recent)
if recent_count > self.recent_limit:
logger.debug("pruning %s of %s recent jobs", recent_count - self.recent_limit, recent_count)
self.recent[:] = self.recent[-self.recent_limit :]
def submit( def submit(
self, self,
@ -197,6 +223,7 @@ class DevicePoolExecutor:
needs_device: Optional[DeviceParams] = None, needs_device: Optional[DeviceParams] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
self.prune()
device = self.get_next_device(needs_device=needs_device) device = self.get_next_device(needs_device=needs_device)
logger.info( logger.info(
"assigning job %s to device %s: %s", key, device, self.devices[device] "assigning job %s to device %s: %s", key, device, self.devices[device]
@ -222,7 +249,7 @@ class DevicePoolExecutor:
future.add_done_callback(job_done) future.add_done_callback(job_done)
def status(self) -> List[Tuple[str, int, bool, int]]: def status(self) -> List[Tuple[str, int, bool, int]]:
return [ pending = [
( (
job.key, job.key,
job.context.device_index.value, job.context.device_index.value,
@ -231,3 +258,15 @@ class DevicePoolExecutor:
) )
for job in self.jobs for job in self.jobs
] ]
recent = [
(
key,
None,
True,
progress,
)
for key, progress in self.recent
]
pending.extend(recent)
return pending