fix(api): maintain list of pending jobs
This commit is contained in:
parent
588c8c7fdb
commit
15b6e036e1
|
@ -27,6 +27,7 @@ class ProgressCommand:
|
|||
|
||||
|
||||
class JobCommand:
|
||||
device: str
|
||||
name: str
|
||||
fn: Callable[..., None]
|
||||
args: Any
|
||||
|
@ -35,6 +36,7 @@ class JobCommand:
|
|||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
device: str,
|
||||
fn: Callable[..., None],
|
||||
args: Any,
|
||||
kwargs: dict[str, Any],
|
||||
|
|
|
@ -24,14 +24,15 @@ class DevicePoolExecutor:
|
|||
|
||||
leaking: List[Tuple[str, Process]]
|
||||
context: Dict[str, WorkerContext] # Device -> Context
|
||||
current: Dict[str, "Value[int]"]
|
||||
current: Dict[str, "Value[int]"] # Device -> pid
|
||||
pending: Dict[str, "Queue[JobCommand]"]
|
||||
threads: Dict[str, Thread]
|
||||
workers: Dict[str, Process]
|
||||
|
||||
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
|
||||
cancelled_jobs: List[str]
|
||||
finished_jobs: List[ProgressCommand]
|
||||
pending_jobs: List[JobCommand]
|
||||
running_jobs: Dict[str, ProgressCommand] # Job name -> job progress
|
||||
total_jobs: Dict[str, int] # Device -> job count
|
||||
|
||||
logs: "Queue[str]"
|
||||
|
@ -57,7 +58,7 @@ class DevicePoolExecutor:
|
|||
self.threads = {}
|
||||
self.workers = {}
|
||||
|
||||
self.active_jobs = {}
|
||||
self.running_jobs = {}
|
||||
self.cancelled_jobs = []
|
||||
self.finished_jobs = []
|
||||
self.total_jobs = {}
|
||||
|
@ -139,31 +140,12 @@ class DevicePoolExecutor:
|
|||
logger_thread.start()
|
||||
|
||||
def create_progress_worker(self) -> None:
|
||||
def update_job(progress: ProgressCommand):
|
||||
if progress.finished:
|
||||
logger.info("job has finished: %s", progress.job)
|
||||
self.finished_jobs.append(progress)
|
||||
del self.active_jobs[progress.job]
|
||||
self.join_leaking()
|
||||
else:
|
||||
logger.debug(
|
||||
"progress update for job: %s to %s", progress.job, progress.progress
|
||||
)
|
||||
self.active_jobs[progress.job] = progress
|
||||
if progress.job in self.cancelled_jobs:
|
||||
logger.debug(
|
||||
"setting flag for cancelled job: %s on %s",
|
||||
progress.job,
|
||||
progress.device,
|
||||
)
|
||||
self.context[progress.device].set_cancel()
|
||||
|
||||
def progress_worker(queue: "Queue[ProgressCommand]"):
|
||||
logger.trace("checking in from progress worker thread")
|
||||
while True:
|
||||
try:
|
||||
progress = queue.get(timeout=(self.join_timeout / 2))
|
||||
update_job(progress)
|
||||
self.update_job(progress)
|
||||
except Empty:
|
||||
pass
|
||||
except ValueError:
|
||||
|
@ -183,7 +165,7 @@ class DevicePoolExecutor:
|
|||
progress_thread.start()
|
||||
|
||||
def get_job_context(self, key: str) -> WorkerContext:
|
||||
device, _progress = self.active_jobs[key]
|
||||
device, _progress = self.running_jobs[key]
|
||||
return self.context[device]
|
||||
|
||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||
|
@ -217,31 +199,43 @@ class DevicePoolExecutor:
|
|||
logger.debug("cannot cancel finished job: %s", key)
|
||||
return False
|
||||
|
||||
if key not in self.active_jobs:
|
||||
for job in self.pending_jobs:
|
||||
if job.name == key:
|
||||
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != key]
|
||||
logger.info("cancelled pending job: %s", key)
|
||||
return True
|
||||
|
||||
if key not in self.running_jobs:
|
||||
logger.debug("cancelled job is not active: %s", key)
|
||||
else:
|
||||
job = self.active_jobs[key]
|
||||
job = self.running_jobs[key]
|
||||
logger.info("cancelling job %s, active on device %s", key, job.device)
|
||||
|
||||
self.cancelled_jobs.append(key)
|
||||
return True
|
||||
|
||||
def done(self, key: str) -> Optional[ProgressCommand]:
|
||||
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
|
||||
"""
|
||||
Check if a job has been finished and report the last progress update.
|
||||
|
||||
If the job is running or finished, the first item will be False and the second item will be its last progress update.
|
||||
If the job is unknown (not running, finished, or pending), the first item will be False and the second item will be None.
|
||||
If the job is still pending, the first item will be True and the second item will be a placeholder ProgressCommand built from the job's name and device.
|
||||
"""
|
||||
if key in self.running_jobs:
|
||||
logger.debug("checking status for running job: %s", key)
|
||||
return (False, self.running_jobs[key])
|
||||
|
||||
for job in self.finished_jobs:
|
||||
if job.job == key:
|
||||
return job
|
||||
logger.debug("checking status for finished job: %s", key)
|
||||
return (False, job)
|
||||
|
||||
if key not in self.active_jobs:
|
||||
logger.debug("checking status for unknown job: %s", key)
|
||||
return None
|
||||
for job in self.pending_jobs:
|
||||
if job.name == key:
|
||||
logger.debug("checking status for pending job: %s", key)
|
||||
return (True, ProgressCommand(job.name, job.device, False, 0))
|
||||
|
||||
return self.active_jobs[key]
|
||||
logger.trace("checking status for unknown job: %s", key)
|
||||
return (False, None)
|
||||
|
||||
def join(self):
|
||||
logger.info("stopping worker pool")
|
||||
|
@ -355,17 +349,21 @@ class DevicePoolExecutor:
|
|||
self.devices[device_idx],
|
||||
)
|
||||
|
||||
# increment job count before recycling (why tho?)
|
||||
device = self.devices[device_idx].device
|
||||
|
||||
if device in self.total_jobs:
|
||||
self.total_jobs[device] += 1
|
||||
else:
|
||||
self.total_jobs[device] = 1
|
||||
|
||||
# recycle before attempting to run
|
||||
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
||||
self.recycle()
|
||||
|
||||
self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)
|
||||
# build and queue job
|
||||
job = JobCommand(key, device, fn, args, kwargs)
|
||||
self.pending_jobs.append(job)
|
||||
self.pending[device].put(job, block=False)
|
||||
|
||||
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
|
||||
history = [
|
||||
|
@ -376,7 +374,7 @@ class DevicePoolExecutor:
|
|||
job.cancel,
|
||||
job.error,
|
||||
)
|
||||
for name, job in self.active_jobs.items()
|
||||
for name, job in self.running_jobs.items()
|
||||
]
|
||||
history.extend(
|
||||
[
|
||||
|
@ -391,3 +389,28 @@ class DevicePoolExecutor:
|
|||
]
|
||||
)
|
||||
return history
|
||||
|
||||
def update_job(self, progress: ProgressCommand):
    """
    Apply a progress update to the job tracking collections.

    A finished update appends the progress to finished_jobs, removes the job
    from running_jobs and cancelled_jobs, and calls join_leaking(). Any other
    update stores the progress in running_jobs. In both cases the job is
    removed from pending_jobs. If the job is still listed in cancelled_jobs
    afterwards, the cancel flag is set on the context for progress.device.
    """
    if progress.finished:
        # move from running to finished
        logger.info("job has finished: %s", progress.job)
        self.finished_jobs.append(progress)
        # a job can report finished before it ever sent a running update,
        # so it may not be in running_jobs; pop() avoids a KeyError here
        self.running_jobs.pop(progress.job, None)
        self.join_leaking()
        if progress.job in self.cancelled_jobs:
            self.cancelled_jobs.remove(progress.job)
    else:
        # move from pending to running
        logger.debug(
            "progress update for job: %s to %s", progress.job, progress.progress
        )
        self.running_jobs[progress.job] = progress

    # once a job has reported anything it is no longer pending; doing this
    # for finished updates too keeps stale entries out of pending_jobs
    self.pending_jobs[:] = [
        job for job in self.pending_jobs if job.name != progress.job
    ]

    if progress.job in self.cancelled_jobs:
        logger.debug(
            "setting flag for cancelled job: %s on %s",
            progress.job,
            progress.device,
        )
        self.context[progress.device].set_cancel()
|
||||
|
|
Loading…
Reference in New Issue