2023-02-04 16:06:22 +00:00
|
|
|
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
|
|
|
|
from logging import getLogger
|
|
|
|
from multiprocessing import Value
|
2023-02-04 16:59:03 +00:00
|
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
2023-02-04 16:06:22 +00:00
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class JobContext:
    '''
    State for a single job: a cancel flag, the index of its device, and a
    progress counter.

    The three mutable fields are stored in `multiprocessing.Value` objects
    and are read and written through their `.value` attribute.
    '''

    def __init__(
        self,
        key: str,
        devices: List[str],
        cancel: bool = False,
        device_index: int = -1,
        progress: int = 0,
    ):
        '''
        key: identifier for the job, used in log messages.
        devices: device names; copied, so later changes to the caller's list
            are not reflected here.
        cancel: initial state of the cancel flag.
        device_index: index into `devices`; negative means not yet assigned.
        progress: initial progress value.
        '''
        self.key = key
        self.devices = list(devices)
        # unsigned char used as a boolean flag
        self.cancel = Value('B', cancel)
        # signed int, so that a negative value can mean "unassigned"
        self.device_index = Value('i', device_index)
        # unsigned int
        self.progress = Value('I', progress)

    def is_cancelled(self) -> bool:
        '''
        Return True if the cancel flag has been set.
        '''
        # the underlying value is an unsigned char (0 or 1); convert it so
        # callers get a real bool, as the annotation promises
        return bool(self.cancel.value)

    def get_device(self) -> str:
        '''
        Get the device assigned to this job.

        Raises an Exception if the device index is negative, meaning no
        device has been assigned yet.
        '''
        with self.device_index.get_lock():
            device_index = self.device_index.value
            if device_index < 0:
                raise Exception('job has not been assigned to a device')

            device = self.devices[device_index]
            logger.debug('job %s assigned to device %s', self.key, device)
            return device

    def get_progress(self) -> int:
        '''
        Return the most recently recorded progress value.
        '''
        return self.progress.value

    def get_progress_callback(self) -> Callable[..., None]:
        '''
        Build a callback taking (step, timestep, latents) that records the
        step as this job's progress.

        The callback raises an Exception instead of recording progress once
        the job has been cancelled; only `step` is used, the other two
        arguments are ignored.
        '''
        def on_progress(step: int, timestep: int, latents: Any):
            if self.is_cancelled():
                raise Exception('job has been cancelled')

            self.set_progress(step)

        return on_progress

    def set_cancel(self, cancel: bool = True) -> None:
        '''
        Set (or, with cancel=False, clear) the cancel flag.
        '''
        with self.cancel.get_lock():
            self.cancel.value = cancel

    def set_progress(self, progress: int) -> None:
        '''
        Record a new progress value and log it at debug level.
        '''
        with self.progress.get_lock():
            self.progress.value = progress

        logger.debug('setting progress for job %s to %s', self.key, progress)
|
2023-02-04 16:06:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Job:
    '''
    Bundle a future with the JobContext that tracks its state, under one key.
    '''

    def __init__(
        self,
        key: str,
        future: Future,
        context: JobContext,
    ):
        # plain attribute storage; nothing is validated or copied
        self.key = key
        self.future = future
        self.context = context

    def get_progress(self) -> int:
        '''
        Read the progress value from the context.
        '''
        current = self.context.get_progress()
        return current

    def set_cancel(self, cancel: bool = True):
        '''
        Forward the cancel flag to the context and return its result.
        '''
        outcome = self.context.set_cancel(cancel)
        return outcome

    def set_progress(self, progress: int):
        '''
        Forward the progress value to the context and return its result.
        '''
        outcome = self.context.set_progress(progress)
        return outcome
|
2023-02-04 16:06:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
class DevicePoolExecutor:
    '''
    Submit jobs to an executor pool and track them by a caller-provided key.
    '''

    devices: List[str] = None
    jobs: List['Job'] = None
    pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None

    def __init__(self, devices: List[str], pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None):
        '''
        devices: device names made available to each job's context.
        pool: executor to submit to; when omitted, a thread pool with one
            worker per device is created.
        '''
        self.devices = devices
        self.jobs = []
        self.pool = pool or ThreadPoolExecutor(len(devices))

    def cancel(self, key: str) -> bool:
        '''
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.

        Returns True when a job with this key was found and either cancelled
        or flagged for cancellation, and False when no job has this key.
        '''
        for job in self.jobs:
            if job.key != key:
                continue

            # Future.cancel() only succeeds before the job starts running;
            # otherwise fall back to setting the job's cancel flag
            if not job.future.cancel():
                job.set_cancel()

            return True

        return False

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        '''
        Look up a job by key and return (done, progress).

        For an unknown key, a warning is logged and (None, 0) is returned.
        '''
        for job in self.jobs:
            if job.key == key:
                done = job.future.done()
                progress = job.get_progress()
                return (done, progress)

        # Logger.warn is a deprecated alias of Logger.warning
        logger.warning('checking status for unknown key: %s', key)
        return (None, 0)

    def prune(self):
        '''
        Replace the job list, in place, with only the jobs whose future is done.
        '''
        # NOTE(review): this keeps finished jobs and drops unfinished ones,
        # which may be the inverse of what was intended - confirm with callers
        self.jobs[:] = [job for job in self.jobs if job.future.done()]

    def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
        '''
        Submit fn to the pool and start tracking it under key.

        fn is called with a new JobContext as its first argument, followed
        by the remaining positional and keyword arguments.
        '''
        # every job is currently given device index 0
        context = JobContext(key, self.devices, device_index=0)
        future = self.pool.submit(fn, context, *args, **kwargs)
        job = Job(key, future, context)
        self.jobs.append(job)

    def status(self) -> List[Tuple[str, bool, int]]:
        '''
        Return a (key, done, progress) tuple for every tracked job.
        '''
        return [(job.key, job.future.done(), job.get_progress()) for job in self.jobs]
|