fix(api): maintain list of pending jobs
This commit is contained in:
parent
588c8c7fdb
commit
15b6e036e1
|
@ -27,6 +27,7 @@ class ProgressCommand:
|
||||||
|
|
||||||
|
|
||||||
class JobCommand:
|
class JobCommand:
|
||||||
|
device: str
|
||||||
name: str
|
name: str
|
||||||
fn: Callable[..., None]
|
fn: Callable[..., None]
|
||||||
args: Any
|
args: Any
|
||||||
|
@ -35,6 +36,7 @@ class JobCommand:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
|
device: str,
|
||||||
fn: Callable[..., None],
|
fn: Callable[..., None],
|
||||||
args: Any,
|
args: Any,
|
||||||
kwargs: dict[str, Any],
|
kwargs: dict[str, Any],
|
||||||
|
|
|
@ -23,15 +23,16 @@ class DevicePoolExecutor:
|
||||||
join_timeout: float
|
join_timeout: float
|
||||||
|
|
||||||
leaking: List[Tuple[str, Process]]
|
leaking: List[Tuple[str, Process]]
|
||||||
context: Dict[str, WorkerContext] # Device -> Context
|
context: Dict[str, WorkerContext] # Device -> Context
|
||||||
current: Dict[str, "Value[int]"]
|
current: Dict[str, "Value[int]"] # Device -> pid
|
||||||
pending: Dict[str, "Queue[JobCommand]"]
|
pending: Dict[str, "Queue[JobCommand]"]
|
||||||
threads: Dict[str, Thread]
|
threads: Dict[str, Thread]
|
||||||
workers: Dict[str, Process]
|
workers: Dict[str, Process]
|
||||||
|
|
||||||
active_jobs: Dict[str, ProgressCommand] # Device -> job progress
|
|
||||||
cancelled_jobs: List[str]
|
cancelled_jobs: List[str]
|
||||||
finished_jobs: List[ProgressCommand]
|
finished_jobs: List[ProgressCommand]
|
||||||
|
pending_jobs: List[JobCommand]
|
||||||
|
running_jobs: Dict[str, ProgressCommand] # Device -> job progress
|
||||||
total_jobs: Dict[str, int] # Device -> job count
|
total_jobs: Dict[str, int] # Device -> job count
|
||||||
|
|
||||||
logs: "Queue[str]"
|
logs: "Queue[str]"
|
||||||
|
@ -57,7 +58,7 @@ class DevicePoolExecutor:
|
||||||
self.threads = {}
|
self.threads = {}
|
||||||
self.workers = {}
|
self.workers = {}
|
||||||
|
|
||||||
self.active_jobs = {}
|
self.running_jobs = {}
|
||||||
self.cancelled_jobs = []
|
self.cancelled_jobs = []
|
||||||
self.finished_jobs = []
|
self.finished_jobs = []
|
||||||
self.total_jobs = {}
|
self.total_jobs = {}
|
||||||
|
@ -139,31 +140,12 @@ class DevicePoolExecutor:
|
||||||
logger_thread.start()
|
logger_thread.start()
|
||||||
|
|
||||||
def create_progress_worker(self) -> None:
|
def create_progress_worker(self) -> None:
|
||||||
def update_job(progress: ProgressCommand):
|
|
||||||
if progress.finished:
|
|
||||||
logger.info("job has finished: %s", progress.job)
|
|
||||||
self.finished_jobs.append(progress)
|
|
||||||
del self.active_jobs[progress.job]
|
|
||||||
self.join_leaking()
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"progress update for job: %s to %s", progress.job, progress.progress
|
|
||||||
)
|
|
||||||
self.active_jobs[progress.job] = progress
|
|
||||||
if progress.job in self.cancelled_jobs:
|
|
||||||
logger.debug(
|
|
||||||
"setting flag for cancelled job: %s on %s",
|
|
||||||
progress.job,
|
|
||||||
progress.device,
|
|
||||||
)
|
|
||||||
self.context[progress.device].set_cancel()
|
|
||||||
|
|
||||||
def progress_worker(queue: "Queue[ProgressCommand]"):
|
def progress_worker(queue: "Queue[ProgressCommand]"):
|
||||||
logger.trace("checking in from progress worker thread")
|
logger.trace("checking in from progress worker thread")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
progress = queue.get(timeout=(self.join_timeout / 2))
|
progress = queue.get(timeout=(self.join_timeout / 2))
|
||||||
update_job(progress)
|
self.update_job(progress)
|
||||||
except Empty:
|
except Empty:
|
||||||
pass
|
pass
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -183,7 +165,7 @@ class DevicePoolExecutor:
|
||||||
progress_thread.start()
|
progress_thread.start()
|
||||||
|
|
||||||
def get_job_context(self, key: str) -> WorkerContext:
    """
    Return the worker context for the device on which a running job is executing.

    Raises KeyError if the job is not in running_jobs, or if its device has
    no entry in the context map.
    """
    # running_jobs holds one ProgressCommand per job name; read the device from
    # that record instead of unpacking it like a (device, progress) tuple
    progress = self.running_jobs[key]
    return self.context[progress.device]
|
||||||
|
|
||||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||||
|
@ -217,31 +199,43 @@ class DevicePoolExecutor:
|
||||||
logger.debug("cannot cancel finished job: %s", key)
|
logger.debug("cannot cancel finished job: %s", key)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if key not in self.active_jobs:
|
for job in self.pending_jobs:
|
||||||
|
if job.name == key:
|
||||||
|
self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != key]
|
||||||
|
logger.info("cancelled pending job: %s", key)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if key not in self.running_jobs:
|
||||||
logger.debug("cancelled job is not active: %s", key)
|
logger.debug("cancelled job is not active: %s", key)
|
||||||
else:
|
else:
|
||||||
job = self.active_jobs[key]
|
job = self.running_jobs[key]
|
||||||
logger.info("cancelling job %s, active on device %s", key, job.device)
|
logger.info("cancelling job %s, active on device %s", key, job.device)
|
||||||
|
|
||||||
self.cancelled_jobs.append(key)
|
self.cancelled_jobs.append(key)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]:
    """
    Report whether a job is still pending, together with its latest progress.

    Returns a `(pending, progress)` pair:

    - running job: `(False, latest ProgressCommand)`
    - finished job: `(False, final ProgressCommand)`
    - pending job: `(True, placeholder ProgressCommand with finished=False and progress 0)`
    - unknown job: `(False, None)`
    """
    if key in self.running_jobs:
        logger.debug("checking status for running job: %s", key)
        return False, self.running_jobs[key]

    # finished jobs are checked before pending ones
    completed = next((item for item in self.finished_jobs if item.job == key), None)
    if completed is not None:
        logger.debug("checking status for finished job: %s", key)
        return False, completed

    queued = next((item for item in self.pending_jobs if item.name == key), None)
    if queued is not None:
        logger.debug("checking status for pending job: %s", key)
        # pending jobs have not reported anything yet, so build an empty progress record
        placeholder = ProgressCommand(queued.name, queued.device, False, 0)
        return True, placeholder

    logger.trace("checking status for unknown job: %s", key)
    return False, None
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
logger.info("stopping worker pool")
|
logger.info("stopping worker pool")
|
||||||
|
@ -355,17 +349,21 @@ class DevicePoolExecutor:
|
||||||
self.devices[device_idx],
|
self.devices[device_idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# increment job count before recycling (why tho?)
|
||||||
device = self.devices[device_idx].device
|
device = self.devices[device_idx].device
|
||||||
|
|
||||||
if device in self.total_jobs:
|
if device in self.total_jobs:
|
||||||
self.total_jobs[device] += 1
|
self.total_jobs[device] += 1
|
||||||
else:
|
else:
|
||||||
self.total_jobs[device] = 1
|
self.total_jobs[device] = 1
|
||||||
|
|
||||||
|
# recycle before attempting to run
|
||||||
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
logger.debug("job count for device %s: %s", device, self.total_jobs[device])
|
||||||
self.recycle()
|
self.recycle()
|
||||||
|
|
||||||
self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False)
|
# build and queue job
|
||||||
|
job = JobCommand(key, device, fn, args, kwargs)
|
||||||
|
self.pending_jobs.append(job)
|
||||||
|
self.pending[device].put(job, block=False)
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
|
def status(self) -> List[Tuple[str, int, bool, bool, bool]]:
|
||||||
history = [
|
history = [
|
||||||
|
@ -376,7 +374,7 @@ class DevicePoolExecutor:
|
||||||
job.cancel,
|
job.cancel,
|
||||||
job.error,
|
job.error,
|
||||||
)
|
)
|
||||||
for name, job in self.active_jobs.items()
|
for name, job in self.running_jobs.items()
|
||||||
]
|
]
|
||||||
history.extend(
|
history.extend(
|
||||||
[
|
[
|
||||||
|
@ -391,3 +389,28 @@ class DevicePoolExecutor:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
def update_job(self, progress: ProgressCommand):
    """
    Apply a progress update from a worker to the pool's job bookkeeping.

    A finished update appends the record to finished_jobs, removes the job
    from running_jobs, pending_jobs and cancelled_jobs, and joins any leaking
    workers. Any other update stores the record in running_jobs, removes the
    job from pending_jobs, and sets the cancel flag on the device context if
    the job has been cancelled.
    """
    name = progress.job

    if progress.finished:
        # move from running to finished
        logger.info("job has finished: %s", name)
        self.finished_jobs.append(progress)

        # if the job finished without a prior non-finished update, it was never
        # added to running_jobs; pop() avoids a KeyError that would skip the
        # cleanup below
        self.running_jobs.pop(name, None)

        # for the same reason it may still be listed as pending
        self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != name]

        self.join_leaking()

        # cancel() appends the key on every call, so drop every occurrence
        # rather than only the first one
        self.cancelled_jobs[:] = [key for key in self.cancelled_jobs if key != name]
        return

    # move from pending to running
    logger.debug("progress update for job: %s to %s", name, progress.progress)
    self.running_jobs[name] = progress
    self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != name]

    if name in self.cancelled_jobs:
        logger.debug(
            "setting flag for cancelled job: %s on %s",
            name,
            progress.device,
        )
        self.context[progress.device].set_cancel()
|
||||||
|
|
Loading…
Reference in New Issue