begin switching to per-device torch mp workers
This commit is contained in:
parent
e03b637f54
commit
e46a1e5fd0
|
@ -1,9 +1,11 @@
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
from concurrent.futures import Future
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from multiprocessing import Value
|
from multiprocessing import Queue
|
||||||
|
from torch.multiprocessing import Lock, Process, SimpleQueue, Value
|
||||||
from traceback import format_exception
|
from traceback import format_exception
|
||||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
from ..params import DeviceParams
|
from ..params import DeviceParams
|
||||||
from ..utils import run_gc
|
from ..utils import run_gc
|
||||||
|
@ -13,6 +15,18 @@ logger = getLogger(__name__)
|
||||||
ProgressCallback = Callable[[int, int, Any], None]
|
ProgressCallback = Callable[[int, int, Any], None]
|
||||||
|
|
||||||
|
|
||||||
|
def worker_init(lock: Lock, job_queue: SimpleQueue):
|
||||||
|
logger.info("checking in from worker")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
if job_queue.empty():
|
||||||
|
logger.info("no jobs, sleeping")
|
||||||
|
sleep(5)
|
||||||
|
else:
|
||||||
|
job = job_queue.get()
|
||||||
|
logger.info("got job: %s", job)
|
||||||
|
|
||||||
|
|
||||||
class JobContext:
|
class JobContext:
|
||||||
cancel: Value = None
|
cancel: Value = None
|
||||||
device_index: Value = None
|
device_index: Value = None
|
||||||
|
@ -104,38 +118,31 @@ class Job:
|
||||||
|
|
||||||
class DevicePoolExecutor:
|
class DevicePoolExecutor:
|
||||||
devices: List[DeviceParams] = None
|
devices: List[DeviceParams] = None
|
||||||
jobs: List[Job] = None
|
finished: List[Tuple[str, int]] = None
|
||||||
next_device: int = 0
|
pending: Dict[str, "Queue[Job]"] = None
|
||||||
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
|
progress: Dict[str, Value] = None
|
||||||
recent: List[Tuple[str, int]] = None
|
workers: Dict[str, Process] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
devices: List[DeviceParams],
|
devices: List[DeviceParams],
|
||||||
pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
|
finished_limit: int = 10,
|
||||||
recent_limit: int = 10,
|
|
||||||
):
|
):
|
||||||
self.devices = devices
|
self.devices = devices
|
||||||
self.jobs = []
|
self.finished = []
|
||||||
self.next_device = 0
|
self.finished_limit = finished_limit
|
||||||
self.recent = []
|
self.lock = Lock()
|
||||||
self.recent_limit = recent_limit
|
self.pending = {}
|
||||||
|
self.progress = {}
|
||||||
|
self.workers = {}
|
||||||
|
|
||||||
device_count = len(devices)
|
# create a pending queue and progress value for each device
|
||||||
if pool is None:
|
for device in devices:
|
||||||
logger.info(
|
name = device.device
|
||||||
"creating thread pool executor for %s devices: %s",
|
job_queue = Queue()
|
||||||
device_count,
|
self.pending[name] = job_queue
|
||||||
[d.device for d in devices],
|
self.progress[name] = Value("I", 0, lock=self.lock)
|
||||||
)
|
self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue))
|
||||||
self.pool = ThreadPoolExecutor(device_count)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"using existing pool for %s devices: %s",
|
|
||||||
device_count,
|
|
||||||
[d.device for d in devices],
|
|
||||||
)
|
|
||||||
self.pool = pool
|
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
def cancel(self, key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
@ -143,31 +150,13 @@ class DevicePoolExecutor:
|
||||||
the future and never execute it. If the job has been started, it
|
the future and never execute it. If the job has been started, it
|
||||||
should be cancelled on the next progress callback.
|
should be cancelled on the next progress callback.
|
||||||
"""
|
"""
|
||||||
for job in self.jobs:
|
raise NotImplementedError()
|
||||||
if job.key == key:
|
|
||||||
if job.future.cancel():
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
job.set_cancel()
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||||
for k, progress in self.recent:
|
for k, progress in self.finished:
|
||||||
if key == k:
|
if key == k:
|
||||||
return (True, progress)
|
return (True, progress)
|
||||||
|
|
||||||
for job in self.jobs:
|
|
||||||
if job.key == key:
|
|
||||||
done = job.future.done()
|
|
||||||
progress = job.get_progress()
|
|
||||||
|
|
||||||
if done:
|
|
||||||
self.prune()
|
|
||||||
|
|
||||||
return (done, progress)
|
|
||||||
|
|
||||||
logger.warn("checking status for unknown key: %s", key)
|
logger.warn("checking status for unknown key: %s", key)
|
||||||
return (None, 0)
|
return (None, 0)
|
||||||
|
|
||||||
|
@ -198,24 +187,14 @@ class DevicePoolExecutor:
|
||||||
return lowest_devices[0]
|
return lowest_devices[0]
|
||||||
|
|
||||||
def prune(self):
|
def prune(self):
|
||||||
pending_jobs = [job for job in self.jobs if job.future.done()]
|
finished_count = len(self.finished)
|
||||||
logger.debug("pruning %s of %s pending jobs", len(pending_jobs), len(self.jobs))
|
if finished_count > self.finished_limit:
|
||||||
|
|
||||||
for job in pending_jobs:
|
|
||||||
self.recent.append((job.key, job.get_progress()))
|
|
||||||
try:
|
|
||||||
self.jobs.remove(job)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.warning("error removing pruned job from pending: %s", e)
|
|
||||||
|
|
||||||
recent_count = len(self.recent)
|
|
||||||
if recent_count > self.recent_limit:
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"pruning %s of %s recent jobs",
|
"pruning %s of %s finished jobs",
|
||||||
recent_count - self.recent_limit,
|
finished_count - self.finished_limit,
|
||||||
recent_count,
|
finished_count,
|
||||||
)
|
)
|
||||||
self.recent[:] = self.recent[-self.recent_limit :]
|
self.finished[:] = self.finished[-self.finished_limit:]
|
||||||
|
|
||||||
def submit(
|
def submit(
|
||||||
self,
|
self,
|
||||||
|
@ -227,49 +206,25 @@ class DevicePoolExecutor:
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.prune()
|
self.prune()
|
||||||
device = self.get_next_device(needs_device=needs_device)
|
device_idx = self.get_next_device(needs_device=needs_device)
|
||||||
logger.info(
|
logger.info(
|
||||||
"assigning job %s to device %s: %s", key, device, self.devices[device]
|
"assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
|
||||||
)
|
)
|
||||||
|
|
||||||
context = JobContext(key, self.devices, device_index=device)
|
context = JobContext(key, self.devices, device_index=device_idx)
|
||||||
future = self.pool.submit(fn, context, *args, **kwargs)
|
device = self.devices[device_idx]
|
||||||
job = Job(key, future, context)
|
|
||||||
self.jobs.append(job)
|
|
||||||
|
|
||||||
def job_done(f: Future):
|
queue = self.pending[device.device]
|
||||||
try:
|
queue.put((fn, context, args, kwargs))
|
||||||
f.result()
|
|
||||||
logger.info("job %s finished successfully", key)
|
|
||||||
except Exception as err:
|
|
||||||
logger.warn(
|
|
||||||
"job %s failed with an error: %s",
|
|
||||||
key,
|
|
||||||
format_exception(type(err), err, err.__traceback__),
|
|
||||||
)
|
|
||||||
run_gc([self.devices[device]])
|
|
||||||
|
|
||||||
future.add_done_callback(job_done)
|
|
||||||
|
|
||||||
def status(self) -> List[Tuple[str, int, bool, int]]:
|
def status(self) -> List[Tuple[str, int, bool, int]]:
|
||||||
pending = [
|
pending = [
|
||||||
(
|
(
|
||||||
job.key,
|
device.device,
|
||||||
job.context.device_index.value,
|
self.pending[device.device].qsize(),
|
||||||
job.future.done(),
|
|
||||||
job.get_progress(),
|
|
||||||
)
|
)
|
||||||
for job in self.jobs
|
for device in self.devices
|
||||||
]
|
]
|
||||||
recent = [
|
pending.extend(self.finished)
|
||||||
(
|
|
||||||
key,
|
|
||||||
None,
|
|
||||||
True,
|
|
||||||
progress,
|
|
||||||
)
|
|
||||||
for key, progress in self.recent
|
|
||||||
]
|
|
||||||
|
|
||||||
pending.extend(recent)
|
|
||||||
return pending
|
return pending
|
||||||
|
|
Loading…
Reference in New Issue