feat(api): distribute jobs to devices using round-robin (#38)
This commit is contained in:
parent
1e38659c80
commit
5e0231c01b
|
@ -92,11 +92,13 @@ class Job:
|
|||
class DevicePoolExecutor:
|
||||
devices: List[DeviceParams] = None
|
||||
jobs: List[Job] = None
|
||||
next_device: int = 0
|
||||
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
|
||||
|
||||
def __init__(self, devices: List[DeviceParams], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
|
||||
self.devices = devices
|
||||
self.jobs = []
|
||||
self.next_device = 0
|
||||
|
||||
device_count = len(devices)
|
||||
if pool is None:
|
||||
|
@ -131,11 +133,16 @@ class DevicePoolExecutor:
|
|||
logger.warn('checking status for unknown key: %s', key)
|
||||
return (None, 0)
|
||||
|
||||
def get_next_device(self):
|
||||
device = self.next_device
|
||||
self.next_device = (self.next_device + 1) % len(self.devices)
|
||||
return device
|
||||
|
||||
def prune(self):
|
||||
self.jobs[:] = [job for job in self.jobs if job.future.done()]
|
||||
|
||||
def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
|
||||
context = JobContext(key, self.devices, device_index=0)
|
||||
context = JobContext(key, self.devices, device_index=self.get_next_device())
|
||||
future = self.pool.submit(fn, context, *args, **kwargs)
|
||||
job = Job(key, future, context)
|
||||
self.jobs.append(job)
|
||||
|
|
Loading…
Reference in New Issue