2023-02-04 22:22:50 +00:00
|
|
|
from collections import Counter
|
2023-02-05 13:53:26 +00:00
|
|
|
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
2023-02-04 16:06:22 +00:00
|
|
|
from logging import getLogger
|
|
|
|
from multiprocessing import Value
|
2023-02-05 14:50:26 +00:00
|
|
|
from traceback import format_exception
|
2023-02-04 22:57:00 +00:00
|
|
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
2023-02-04 16:06:22 +00:00
|
|
|
|
2023-02-14 00:04:46 +00:00
|
|
|
from ..params import DeviceParams
|
|
|
|
from ..utils import run_gc
|
2023-02-04 19:49:34 +00:00
|
|
|
|
2023-02-04 16:06:22 +00:00
|
|
|
# module-level logger, named after this module
logger = getLogger(name=__name__)

# signature of the per-step callback built by JobContext.get_progress_callback:
# (step, timestep, latents) -> None
ProgressCallback = Callable[
    [int, int, Any],
    None,
]
|
|
|
|
|
2023-02-04 16:06:22 +00:00
|
|
|
|
|
|
|
class JobContext:
    """
    Mutable state for a single job, shared between the executor and the worker.

    The cancel flag, device index, and progress counter are each held in a
    `multiprocessing.Value` (a ctypes object paired with its own lock).
    """

    cancel: Value = None
    device_index: Value = None
    devices: List[DeviceParams] = None
    key: str = None
    progress: Value = None

    def __init__(
        self,
        key: str,
        devices: List[DeviceParams],
        cancel: bool = False,
        device_index: int = -1,
        progress: int = 0,
    ):
        """
        Args:
            key: identifier of the job (appears in log messages).
            devices: devices the job may be assigned to; copied into a new list.
            cancel: initial value of the cancellation flag.
            device_index: index into `devices`, or -1 when not yet assigned.
            progress: initial progress value (stored as an unsigned int, so
                negative values would wrap around).
        """
        self.key = key
        self.devices = list(devices)
        # typecodes: "B" = unsigned char, "i" = signed int, "I" = unsigned int
        self.cancel = Value("B", cancel)
        self.device_index = Value("i", device_index)
        self.progress = Value("I", progress)

    def is_cancelled(self) -> bool:
        """Return True if cancellation has been requested for this job."""
        # the shared value is an unsigned char (0/1); normalize it to a real bool
        return bool(self.cancel.value)

    def get_device(self) -> DeviceParams:
        """
        Get the device assigned to this job.

        Raises:
            Exception: if no device has been assigned (the index is negative).
        """
        # only the read of the shared index needs the lock
        with self.device_index.get_lock():
            device_index = self.device_index.value

        if device_index < 0:
            raise Exception("job has not been assigned to a device")

        device = self.devices[device_index]
        logger.debug("job %s assigned to device %s", self.key, device)
        return device

    def get_progress(self) -> int:
        """Return the last progress value recorded for this job."""
        return self.progress.value

    def get_progress_callback(self) -> ProgressCallback:
        """
        Build a per-step callback that records progress and honors cancellation.

        The callback stores the latest step on its own `step` attribute. It
        raises if the job has been cancelled. Otherwise it writes the step to
        the shared progress value.
        """

        def on_progress(step: int, timestep: int, latents: Any):
            # remember the last step seen, exposed as an attribute of the callback
            on_progress.step = step
            if self.is_cancelled():
                # raising aborts whatever loop is invoking this callback
                raise Exception("job has been cancelled")

            logger.debug("setting progress for job %s to %s", self.key, step)
            self.set_progress(step)

        return on_progress

    def set_cancel(self, cancel: bool = True) -> None:
        """Set (or clear) the cancellation flag under its lock."""
        with self.cancel.get_lock():
            self.cancel.value = cancel

    def set_progress(self, progress: int) -> None:
        """Record the job's progress under its lock."""
        with self.progress.get_lock():
            self.progress.value = progress
|
|
|
|
|
|
|
|
|
|
|
|
class Job:
    """
    Link a future to its context.

    The `Future` reports whether the work has finished, while progress reads
    and cancellation requests are forwarded to the job's `JobContext`.
    """

    context: JobContext = None
    future: Future = None
    key: str = None

    def __init__(
        self,
        key: str,
        future: Future,
        context: JobContext,
    ):
        # independent assignments; order does not matter
        self.key = key
        self.future = future
        self.context = context

    def get_progress(self) -> int:
        """Return the progress currently stored in the job's context."""
        ctx = self.context
        return ctx.get_progress()

    def set_cancel(self, cancel: bool = True):
        """Forward a cancellation request to the job's context."""
        ctx = self.context
        return ctx.set_cancel(cancel)

    def set_progress(self, progress: int):
        """Forward a progress update to the job's context."""
        ctx = self.context
        return ctx.set_progress(progress)
|
2023-02-04 16:06:22 +00:00
|
|
|
|
|
|
|
|
|
|
|
class DevicePoolExecutor:
    """
    Run jobs on an executor pool, assigning each job to one of several devices.

    Jobs are tracked by key so they can be cancelled, polled, and pruned.
    """

    devices: List[DeviceParams] = None
    jobs: List[Job] = None
    next_device: int = 0
    pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None

    def __init__(
        self,
        devices: List[DeviceParams],
        pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
    ):
        """
        Args:
            devices: devices that jobs may be assigned to.
            pool: executor to submit jobs to. When omitted, a thread pool with
                one worker per device is created.
        """
        self.devices = devices
        self.jobs = []
        self.next_device = 0

        device_count = len(devices)
        if pool is None:
            logger.info(
                "creating thread pool executor for %s devices: %s",
                device_count,
                [d.device for d in devices],
            )
            self.pool = ThreadPoolExecutor(device_count)
        else:
            logger.info(
                "using existing pool for %s devices: %s",
                device_count,
                [d.device for d in devices],
            )
            self.pool = pool

    def cancel(self, key: str) -> bool:
        """
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.

        Returns:
            True if a job with the given key was found, False otherwise.
        """
        for job in self.jobs:
            if job.key == key:
                if not job.future.cancel():
                    # already running (or finished): set the shared flag so the
                    # progress callback raises on its next step
                    job.set_cancel()
                return True

        return False

    def done(self, key: str) -> Tuple[Optional[bool], int]:
        """
        Check the status of a job.

        Returns:
            `(done, progress)` for a tracked job, or `(None, 0)` if no tracked
            job has the given key.
        """
        for job in self.jobs:
            if job.key == key:
                done = job.future.done()
                progress = job.get_progress()
                return (done, progress)

        logger.warning("checking status for unknown key: %s", key)
        return (None, 0)

    def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
        """
        Pick the index of the device that should run the next job.

        A matching `needs_device` is honored first. Otherwise the device with
        the fewest unfinished jobs wins, with ties going to the lowest index.
        """
        # respect overrides if possible
        if needs_device is not None:
            for i, device in enumerate(self.devices):
                if device.device == needs_device.device:
                    return i

        # use the first/default device if there are no jobs
        if not self.jobs:
            return 0

        job_devices = [
            job.context.device_index.value for job in self.jobs if not job.future.done()
        ]
        # seed every device with one entry so idle devices still appear in the count
        job_counts = Counter(range(len(self.devices)))
        job_counts.update(job_devices)

        queued = job_counts.most_common()
        logger.debug("jobs queued by device: %s", queued)

        # most_common() sorts by descending count, so the last entry is the minimum
        lowest_count = queued[-1][1]
        lowest_devices = sorted(d[0] for d in queued if d[1] == lowest_count)
        return lowest_devices[0]

    def prune(self):
        """Drop finished jobs, keeping only those that are pending or running."""
        self.jobs[:] = [job for job in self.jobs if not job.future.done()]

    def submit(
        self,
        key: str,
        fn: Callable[..., None],
        /,
        *args,
        needs_device: Optional[DeviceParams] = None,
        **kwargs,
    ) -> None:
        """
        Submit `fn(context, *args, **kwargs)` to the pool as a tracked job.

        Args:
            key: identifier used to cancel or poll the job later.
            fn: the work to run. It receives the job's `JobContext` as its
                first argument.
            needs_device: optional device the job should run on, when available.
        """
        device = self.get_next_device(needs_device=needs_device)
        logger.info(
            "assigning job %s to device %s: %s", key, device, self.devices[device]
        )

        context = JobContext(key, self.devices, device_index=device)
        future = self.pool.submit(fn, context, *args, **kwargs)
        job = Job(key, future, context)
        self.jobs.append(job)

        def job_done(f: Future):
            if f.cancelled():
                # cancelled before it started; not an error
                logger.info("job %s was cancelled before starting", key)
            else:
                try:
                    f.result()
                    logger.info("job %s finished successfully", key)
                except Exception as err:
                    logger.warning(
                        "job %s failed with an error: %s",
                        key,
                        "".join(format_exception(type(err), err, err.__traceback__)),
                    )

            run_gc()

        future.add_done_callback(job_done)

    def status(self) -> List[Tuple[str, int, bool, int]]:
        """Return `(key, device_index, done, progress)` for every tracked job."""
        return [
            (
                job.key,
                job.context.device_index.value,
                job.future.done(),
                job.get_progress(),
            )
            for job in self.jobs
        ]
|