1
0
Fork 0

fix(api): maintain list of pending jobs

This commit is contained in:
Sean Sube 2023-03-18 17:15:18 -05:00
parent 588c8c7fdb
commit 15b6e036e1
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 63 additions and 38 deletions

View File

@ -27,6 +27,7 @@ class ProgressCommand:
class JobCommand:
device: str
name: str
fn: Callable[..., None]
args: Any
@ -35,6 +36,7 @@ class JobCommand:
def __init__(
self,
name: str,
device: str,
fn: Callable[..., None],
args: Any,
kwargs: dict[str, Any],

View File

@ -23,15 +23,16 @@ class DevicePoolExecutor:
join_timeout: float
leaking: List[Tuple[str, Process]]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"] # Device -> pid
pending: Dict[str, "Queue[JobCommand]"]
threads: Dict[str, Thread]
workers: Dict[str, Process]
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
cancelled_jobs: List[str]
finished_jobs: List[ProgressCommand]
pending_jobs: List[JobCommand]
running_jobs: Dict[str, ProgressCommand] # Device -> job progress
total_jobs: Dict[str, int] # Device -> job count
logs: "Queue[str]"
@ -57,7 +58,7 @@ class DevicePoolExecutor:
self.threads = {}
self.workers = {}
self.active_jobs = {}
self.running_jobs = {}
self.cancelled_jobs = []
self.finished_jobs = []
self.total_jobs = {}
@ -139,31 +140,12 @@ class DevicePoolExecutor:
logger_thread.start()
def create_progress_worker(self) -> None:
def update_job(progress: ProgressCommand):
if progress.finished:
logger.info("job has finished: %s", progress.job)
self.finished_jobs.append(progress)
del self.active_jobs[progress.job]
self.join_leaking()
else:
logger.debug(
"progress update for job: %s to %s", progress.job, progress.progress
)
self.active_jobs[progress.job] = progress
if progress.job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s",
progress.job,
progress.device,
)
self.context[progress.device].set_cancel()
def progress_worker(queue: "Queue[ProgressCommand]"):
logger.trace("checking in from progress worker thread")
while True:
try:
progress = queue.get(timeout=(self.join_timeout / 2))
update_job(progress)
self.update_job(progress)
except Empty:
pass
except ValueError:
@ -183,7 +165,7 @@ class DevicePoolExecutor:
progress_thread.start()
def get_job_context(self, key: str) -> WorkerContext:
device, _progress = self.active_jobs[key]
device, _progress = self.running_jobs[key]
return self.context[device]
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
@ -217,31 +199,43 @@ class DevicePoolExecutor:
logger.debug("cannot cancel finished job: %s", key)
return False
if key not in self.active_jobs:
for job in self.pending_jobs:
if job.name == key:
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != key]
logger.info("cancelled pending job: %s", key)
return True
if key not in self.running_jobs:
logger.debug("cancelled job is not active: %s", key)
else:
job = self.active_jobs[key]
job = self.running_jobs[key]
logger.info("cancelling job %s, active on device %s", key, job.device)
self.cancelled_jobs.append(key)
return True
def done(self, key: str) -> Optional[ProgressCommand]:
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
"""
Check if a job has been finished and report the last progress update.
If the job is still active or pending, the first item will be False.
If the job is not finished or active, the first item will be None.
If the job is still pending, the first item will be True and there will be no ProgressCommand.
"""
if key in self.running_jobs:
logger.debug("checking status for running job: %s", key)
return (False, self.running_jobs[key])
for job in self.finished_jobs:
if job.job == key:
return job
logger.debug("checking status for finished job: %s", key)
return (False, job)
if key not in self.active_jobs:
logger.debug("checking status for unknown job: %s", key)
return None
for job in self.pending_jobs:
if job.name == key:
logger.debug("checking status for pending job: %s", key)
return (True, ProgressCommand(job.name, job.device, False, 0))
return self.active_jobs[key]
logger.trace("checking status for unknown job: %s", key)
return (False, None)
def join(self):
logger.info("stopping worker pool")
@ -355,17 +349,21 @@ class DevicePoolExecutor:
self.devices[device_idx],
)
# increment job count before recycling (why tho?)
device = self.devices[device_idx].device
if device in self.total_jobs:
self.total_jobs[device] += 1
else:
self.total_jobs[device] = 1
# recycle before attempting to run
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
self.recycle()
self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)
# build and queue job
job = JobCommand(key, device, fn, args, kwargs)
self.pending_jobs.append(job)
self.pending[device].put(job, block=False)
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
history = [
@ -376,7 +374,7 @@ class DevicePoolExecutor:
job.cancel,
job.error,
)
for name, job in self.active_jobs.items()
for name, job in self.running_jobs.items()
]
history.extend(
[
@ -391,3 +389,28 @@ class DevicePoolExecutor:
]
)
return history
def update_job(self, progress: ProgressCommand):
if progress.finished:
# move from running to finished
logger.info("job has finished: %s", progress.job)
self.finished_jobs.append(progress)
del self.running_jobs[progress.job]
self.join_leaking()
if progress.job in self.cancelled_jobs:
self.cancelled_jobs.remove(progress.job)
else:
# move from pending to running
logger.debug(
"progress update for job: %s to %s", progress.job, progress.progress
)
self.running_jobs[progress.job] = progress
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != progress.job]
if progress.job in self.cancelled_jobs:
logger.debug(
"setting flag for cancelled job: %s on %s",
progress.job,
progress.device,
)
self.context[progress.device].set_cancel()