diff --git a/api/onnx_web/worker/command.py b/api/onnx_web/worker/command.py index 47eb75d3..5a05e246 100644 --- a/api/onnx_web/worker/command.py +++ b/api/onnx_web/worker/command.py @@ -27,6 +27,7 @@ class ProgressCommand: class JobCommand: + device: str name: str fn: Callable[..., None] args: Any @@ -35,6 +36,7 @@ class JobCommand: def __init__( self, name: str, + device: str, fn: Callable[..., None], args: Any, kwargs: dict[str, Any], diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index dc4b95fc..22a74d0a 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -23,15 +23,16 @@ class DevicePoolExecutor: join_timeout: float leaking: List[Tuple[str, Process]] - context: Dict[str, WorkerContext] # Device -> Context - current: Dict[str, "Value[int]"] + context: Dict[str, WorkerContext] # Device -> Context + current: Dict[str, "Value[int]"] # Device -> pid pending: Dict[str, "Queue[JobCommand]"] threads: Dict[str, Thread] workers: Dict[str, Process] - active_jobs: Dict[str, ProgressCommand] # Device -> job progress cancelled_jobs: List[str] finished_jobs: List[ProgressCommand] + pending_jobs: List[JobCommand] + running_jobs: Dict[str, ProgressCommand] # Device -> job progress total_jobs: Dict[str, int] # Device -> job count logs: "Queue[str]" @@ -57,7 +58,7 @@ class DevicePoolExecutor: self.threads = {} self.workers = {} - self.active_jobs = {} + self.running_jobs = {} self.cancelled_jobs = [] self.finished_jobs = [] self.total_jobs = {} @@ -139,31 +140,12 @@ class DevicePoolExecutor: logger_thread.start() def create_progress_worker(self) -> None: - def update_job(progress: ProgressCommand): - if progress.finished: - logger.info("job has finished: %s", progress.job) - self.finished_jobs.append(progress) - del self.active_jobs[progress.job] - self.join_leaking() - else: - logger.debug( - "progress update for job: %s to %s", progress.job, progress.progress - ) - self.active_jobs[progress.job] = progress - if progress.job in self.cancelled_jobs: - logger.debug( - "setting flag for cancelled job: %s on %s", - progress.job, - progress.device, - ) - self.context[progress.device].set_cancel() - def progress_worker(queue: "Queue[ProgressCommand]"): logger.trace("checking in from progress worker thread") while True: try: progress = queue.get(timeout=(self.join_timeout / 2)) - update_job(progress) + self.update_job(progress) except Empty: pass except ValueError: @@ -183,7 +165,7 @@ class DevicePoolExecutor: progress_thread.start() def get_job_context(self, key: str) -> WorkerContext: - device, _progress = self.active_jobs[key] + device, _progress = self.running_jobs[key] return self.context[device] def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: @@ -217,31 +199,43 @@ class DevicePoolExecutor: logger.debug("cannot cancel finished job: %s", key) return False - if key not in self.active_jobs: + for job in self.pending_jobs: + if job.name == key: + self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != key] + logger.info("cancelled pending job: %s", key) + return True + + if key not in self.running_jobs: logger.debug("cancelled job is not active: %s", key) else: - job = self.active_jobs[key] + job = self.running_jobs[key] logger.info("cancelling job %s, active on device %s", key, job.device) self.cancelled_jobs.append(key) return True - def done(self, key: str) -> Optional[ProgressCommand]: + def done(self, key: str) -> Tuple[bool, Optional[ProgressCommand]]: """ Check if a job has been finished and report the last progress update. - If the job is still active or pending, the first item will be False. - If the job is not finished or active, the first item will be None. + If the job is still pending, the first item will be True and there will be no ProgressCommand. """ + if key in self.running_jobs: + logger.debug("checking status for running job: %s", key) + return (False, self.running_jobs[key]) + for job in self.finished_jobs: if job.job == key: - return job + logger.debug("checking status for finished job: %s", key) + return (False, job) - if key not in self.active_jobs: - logger.debug("checking status for unknown job: %s", key) - return None + for job in self.pending_jobs: + if job.name == key: + logger.debug("checking status for pending job: %s", key) + return (True, ProgressCommand(job.name, job.device, False, 0)) - return self.active_jobs[key] + logger.trace("checking status for unknown job: %s", key) + return (False, None) def join(self): logger.info("stopping worker pool") @@ -355,17 +349,21 @@ class DevicePoolExecutor: self.devices[device_idx], ) + # increment job count before recycling (why tho?) device = self.devices[device_idx].device - if device in self.total_jobs: self.total_jobs[device] += 1 else: self.total_jobs[device] = 1 + # recycle before attempting to run logger.debug("job count for device %s: %s", device, self.total_jobs[device]) self.recycle() - self.pending[device].put(JobCommand(key, fn, args, kwargs), block=False) + # build and queue job + job = JobCommand(key, device, fn, args, kwargs) + self.pending_jobs.append(job) + self.pending[device].put(job, block=False) def status(self) -> List[Tuple[str, int, bool, bool, bool]]: history = [ @@ -376,7 +374,7 @@ class DevicePoolExecutor: job.cancel, job.error, ) - for name, job in self.active_jobs.items() + for name, job in self.running_jobs.items() ] history.extend( [ @@ -391,3 +389,28 @@ class DevicePoolExecutor: ] ) return history + + def update_job(self, progress: ProgressCommand): + if progress.finished: + # move from running to finished + logger.info("job has finished: %s", progress.job) + self.finished_jobs.append(progress) + del self.running_jobs[progress.job] + self.join_leaking() + if progress.job in self.cancelled_jobs: + self.cancelled_jobs.remove(progress.job) + else: + # move from pending to running + logger.debug( + "progress update for job: %s to %s", progress.job, progress.progress + ) + self.running_jobs[progress.job] = progress + self.pending_jobs[:] = [job for job in self.pending_jobs if job.name != progress.job] + + if progress.job in self.cancelled_jobs: + logger.debug( + "setting flag for cancelled job: %s on %s", + progress.job, + progress.device, + ) + self.context[progress.device].set_cancel()