from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
from logging import getLogger
from multiprocessing import Value
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from .params import (
    DeviceParams,
)

logger = getLogger(__name__)


class JobContext:
    '''
    Per-job state handed to the job function: its key, the devices it may
    run on, and a cancel flag, assigned device index, and progress counter.

    The mutable fields are `multiprocessing.Value`s, each guarded by its
    own lock when written.
    '''

    def __init__(
        self,
        key: str,
        devices: List[DeviceParams],
        cancel: bool = False,
        device_index: int = -1,
        progress: int = 0,
    ):
        self.key = key
        # copy so later changes to the caller's list do not affect this job
        self.devices = list(devices)
        self.cancel = Value('B', cancel)  # unsigned char used as a flag
        self.device_index = Value('i', device_index)  # signed: -1 = unassigned
        self.progress = Value('I', progress)  # unsigned int

    def is_cancelled(self) -> bool:
        '''
        Return True if cancellation has been requested for this job.
        '''
        # the shared value is an unsigned char; normalize to a real bool
        return bool(self.cancel.value)

    def get_device(self) -> DeviceParams:
        '''
        Get the device assigned to this job.

        Raises:
            Exception: if no device has been assigned (index is negative).
        '''
        with self.device_index.get_lock():
            device_index = self.device_index.value
            if device_index < 0:
                raise Exception('job has not been assigned to a device')
            else:
                device = self.devices[device_index]
                logger.debug('job %s assigned to device %s', self.key, device)
                return device

    def get_progress(self) -> int:
        '''
        Return the last progress value recorded for this job.
        '''
        return self.progress.value

    def get_progress_callback(self) -> Callable[..., None]:
        '''
        Build a callback that records `step` as this job's progress.

        The callback raises an Exception when the job has been cancelled,
        which aborts whatever is driving it. `timestep` and `latents` are
        accepted but not used here.
        '''
        def on_progress(step: int, timestep: int, latents: Any):
            if self.is_cancelled():
                raise Exception('job has been cancelled')
            else:
                logger.debug('setting progress for job %s to %s', self.key, step)
                self.set_progress(step)

        return on_progress

    def set_cancel(self, cancel: bool = True) -> None:
        '''
        Set or clear the cancellation flag.
        '''
        with self.cancel.get_lock():
            self.cancel.value = cancel

    def set_progress(self, progress: int) -> None:
        '''
        Record the job's current progress.
        '''
        with self.progress.get_lock():
            self.progress.value = progress


class Job:
    '''
    Link a future to its context.
    '''

    def __init__(
        self,
        key: str,
        future: Future,
        context: JobContext,
    ):
        self.context = context
        self.future = future
        self.key = key

    def get_progress(self) -> int:
        '''
        Return the progress recorded in this job's context.
        '''
        return self.context.get_progress()

    def set_cancel(self, cancel: bool = True):
        '''
        Set or clear the cancellation flag in this job's context.
        '''
        return self.context.set_cancel(cancel)

    def set_progress(self, progress: int):
        '''
        Record progress in this job's context.
        '''
        return self.context.set_progress(progress)


class DevicePoolExecutor:
    '''
    Submit jobs to a worker pool, assigning each job one of the configured
    devices and tracking its future and progress by key.

    NOTE(review): `multiprocessing.Value` objects generally cannot be pickled
    into a task, so passing a ProcessPoolExecutor as `pool` may fail on
    submit — confirm a process pool actually works with JobContext.
    '''

    devices: Optional[List[DeviceParams]] = None
    jobs: Optional[List[Job]] = None
    pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None

    def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
        self.devices = devices
        self.jobs = []

        device_count = len(devices)
        if pool is None:
            logger.info(
                'creating thread pool executor for %s devices: %s', device_count, devices)
            # one worker per device
            self.pool = ThreadPoolExecutor(device_count)
        else:
            logger.info('using existing pool for %s devices: %s', device_count, devices)
            self.pool = pool

    def cancel(self, key: str) -> bool:
        '''
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.

        Returns True if a job with the given key was found (either cancelled
        outright or flagged for cancellation), False if no such job exists.
        '''
        found = False
        for job in self.jobs:
            if job.key != key:
                continue

            found = True
            if job.future.cancel():
                return True

            # the future is already running (or finished), so it cannot be
            # cancelled directly; ask the job to stop at its next callback
            job.set_cancel()

        if not found:
            logger.warning('cancelling unknown key: %s', key)

        return found

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        '''
        Report whether the job with the given key has finished, and its progress.

        Returns (None, 0) when no job with that key is tracked, for example
        after it has been pruned.
        '''
        for job in self.jobs:
            if job.key == key:
                return (job.future.done(), job.get_progress())

        logger.warning('checking status for unknown key: %s', key)
        return (None, 0)

    def prune(self) -> None:
        '''
        Drop finished jobs (completed, failed, or cancelled) and keep the
        ones that are still pending or running.
        '''
        # slice assignment updates the existing list object in place
        self.jobs[:] = [job for job in self.jobs if not job.future.done()]

    def _next_device_index(self) -> int:
        '''
        Pick the device with the fewest unfinished jobs (lowest index on ties).
        '''
        counts = [0] * len(self.devices)
        for job in self.jobs:
            if job.future.done():
                continue

            index = job.context.device_index.value
            if 0 <= index < len(counts):
                counts[index] += 1

        if not counts:
            # no devices configured: fall back to the previous default
            return 0

        # min() returns the first minimal index, so ties go to the lowest device
        return min(range(len(counts)), key=counts.__getitem__)

    def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
        '''
        Submit a job to the pool and start tracking it under `key`.

        `fn` is called as `fn(context, *args, **kwargs)`, where `context` is
        a new JobContext assigned to the least-loaded device.
        '''
        device_index = self._next_device_index()
        context = JobContext(key, self.devices, device_index=device_index)
        future = self.pool.submit(fn, context, *args, **kwargs)
        job = Job(key, future, context)
        self.jobs.append(job)

        def job_done(f: Future):
            # CancelledError is an Exception subclass; report cancellation
            # separately instead of as a failure
            if f.cancelled():
                logger.info('job %s was cancelled', key)
                return

            err = f.exception()  # the future is done, so this does not block
            if err is None:
                logger.info('job %s finished successfully', key)
            else:
                logger.warning('job %s failed with an error: %s', key, err)

        future.add_done_callback(job_done)

    def status(self) -> List[Tuple[str, bool, int]]:
        '''
        Return a (key, done, progress) tuple for every tracked job.
        '''
        return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs]