
begin switching to per-device torch mp workers

Sean Sube 2023-02-25 23:16:32 -06:00
parent e03b637f54
commit e46a1e5fd0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 53 additions and 98 deletions


@@ -1,9 +1,11 @@
 from collections import Counter
-from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
+from concurrent.futures import Future
 from logging import getLogger
-from multiprocessing import Value
+from multiprocessing import Queue
+from torch.multiprocessing import Lock, Process, SimpleQueue, Value
 from traceback import format_exception
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from time import sleep
 
 from ..params import DeviceParams
 from ..utils import run_gc
@@ -13,6 +15,18 @@ logger = getLogger(__name__)
 ProgressCallback = Callable[[int, int, Any], None]
 
 
+def worker_init(lock: Lock, job_queue: SimpleQueue):
+    logger.info("checking in from worker")
+
+    while True:
+        if job_queue.empty():
+            logger.info("no jobs, sleeping")
+            sleep(5)
+        else:
+            job = job_queue.get()
+            logger.info("got job: %s", job)
+
+
 class JobContext:
     cancel: Value = None
     device_index: Value = None
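The worker loop in this commit only logs the jobs it dequeues. A minimal sketch of where it is presumably headed, based on the (fn, context, args, kwargs) tuples that submit() enqueues later in this diff; the unpacking and error handling here are assumptions, not part of the commit:

def worker_loop(job_queue):
    while True:
        # Queue.get() blocks until a job is available, which would also make
        # the empty()/sleep(5) polling above unnecessary
        fn, context, args, kwargs = job_queue.get()
        try:
            fn(context, *args, **kwargs)
        except Exception as err:
            logger.warning("job failed: %s", err)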
@@ -104,38 +118,31 @@ class Job:
 class DevicePoolExecutor:
     devices: List[DeviceParams] = None
-    jobs: List[Job] = None
-    next_device: int = 0
-    pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None
-    recent: List[Tuple[str, int]] = None
+    finished: List[Tuple[str, int]] = None
+    pending: Dict[str, "Queue[Job]"] = None
+    progress: Dict[str, Value] = None
+    workers: Dict[str, Process] = None
 
     def __init__(
         self,
         devices: List[DeviceParams],
-        pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
-        recent_limit: int = 10,
+        finished_limit: int = 10,
     ):
         self.devices = devices
-        self.jobs = []
-        self.next_device = 0
-        self.recent = []
-        self.recent_limit = recent_limit
+        self.finished = []
+        self.finished_limit = finished_limit
+        self.lock = Lock()
+        self.pending = {}
+        self.progress = {}
+        self.workers = {}
 
-        device_count = len(devices)
-        if pool is None:
-            logger.info(
-                "creating thread pool executor for %s devices: %s",
-                device_count,
-                [d.device for d in devices],
-            )
-            self.pool = ThreadPoolExecutor(device_count)
-        else:
-            logger.info(
-                "using existing pool for %s devices: %s",
-                device_count,
-                [d.device for d in devices],
-            )
-            self.pool = pool
+        # create a pending queue and progress value for each device
+        for device in devices:
+            name = device.device
+            job_queue = Queue()
+            self.pending[name] = job_queue
+            self.progress[name] = Value("I", 0, lock=self.lock)
+            self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue))
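The Process objects are constructed here but never started, and nothing writes the per-device progress Values yet. A hedged sketch of the follow-up wiring this implies; the start() loop and the example device name "cuda:0" are assumptions, not part of this commit:

# presumably at the end of __init__, one worker per device:
for worker in self.workers.values():
    worker.start()

# a worker-side progress callback could then publish steps through the
# shared Value, which the parent reads back in done() and status():
self.progress["cuda:0"].value = 42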
 
     def cancel(self, key: str) -> bool:
         """
@@ -143,31 +150,13 @@ class DevicePoolExecutor:
         the future and never execute it. If the job has been started, it
         should be cancelled on the next progress callback.
         """
-        for job in self.jobs:
-            if job.key == key:
-                if job.future.cancel():
-                    return True
-                else:
-                    job.set_cancel()
-                    return True
-
-        return False
+        raise NotImplementedError()
 
     def done(self, key: str) -> Tuple[Optional[bool], int]:
-        for k, progress in self.recent:
+        for k, progress in self.finished:
             if key == k:
                 return (True, progress)
 
-        for job in self.jobs:
-            if job.key == key:
-                done = job.future.done()
-                progress = job.get_progress()
-                if done:
-                    self.prune()
-                return (done, progress)
-
         logger.warn("checking status for unknown key: %s", key)
         return (None, 0)
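With the jobs list gone, cancel() is stubbed out for now. One hedged way it could be rebuilt on the primitives this commit keeps: JobContext.cancel is already a shared Value, so flipping it lets the worker notice on its next progress callback, matching the docstring above. The running dict is a hypothetical key-to-JobContext map that this commit does not define:

def cancel(self, key: str) -> bool:
    context = self.running.get(key)  # hypothetical Dict[str, JobContext]
    if context is None:
        return False
    with context.cancel.get_lock():
        context.cancel.value = True  # checked by the worker between steps
    return True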
@@ -198,24 +187,14 @@
         return lowest_devices[0]
 
     def prune(self):
-        pending_jobs = [job for job in self.jobs if job.future.done()]
-        logger.debug("pruning %s of %s pending jobs", len(pending_jobs), len(self.jobs))
-
-        for job in pending_jobs:
-            self.recent.append((job.key, job.get_progress()))
-            try:
-                self.jobs.remove(job)
-            except ValueError as e:
-                logger.warning("error removing pruned job from pending: %s", e)
-
-        recent_count = len(self.recent)
-        if recent_count > self.recent_limit:
+        finished_count = len(self.finished)
+        if finished_count > self.finished_limit:
             logger.debug(
-                "pruning %s of %s recent jobs",
-                recent_count - self.recent_limit,
-                recent_count,
+                "pruning %s of %s finished jobs",
+                finished_count - self.finished_limit,
+                finished_count,
             )
-            self.recent[:] = self.recent[-self.recent_limit :]
+            self.finished[:] = self.finished[-self.finished_limit:]
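The slice assignment on the last added line is deliberate: it truncates the existing list in place, so any other reference to self.finished sees the pruned result, where a plain rebinding would leave stale aliases behind. A quick illustration:

finished = [("a", 100), ("b", 100), ("c", 100)]
alias = finished
finished[:] = finished[-2:]  # in-place truncation, alias sees it too
assert alias == [("b", 100), ("c", 100)]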
 
     def submit(
         self,
@@ -227,49 +206,25 @@
         **kwargs,
     ) -> None:
         self.prune()
-        device = self.get_next_device(needs_device=needs_device)
+        device_idx = self.get_next_device(needs_device=needs_device)
         logger.info(
-            "assigning job %s to device %s: %s", key, device, self.devices[device]
+            "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
         )
 
-        context = JobContext(key, self.devices, device_index=device)
-        future = self.pool.submit(fn, context, *args, **kwargs)
-        job = Job(key, future, context)
-        self.jobs.append(job)
+        context = JobContext(key, self.devices, device_index=device_idx)
+        device = self.devices[device_idx]
 
-        def job_done(f: Future):
-            try:
-                f.result()
-                logger.info("job %s finished successfully", key)
-            except Exception as err:
-                logger.warn(
-                    "job %s failed with an error: %s",
-                    key,
-                    format_exception(type(err), err, err.__traceback__),
-                )
-                run_gc([self.devices[device]])
+        queue = self.pending[device.device]
+        queue.put((fn, context, args, kwargs))
 
-        future.add_done_callback(job_done)
 
     def status(self) -> List[Tuple[str, int, bool, int]]:
         pending = [
             (
-                job.key,
-                job.context.device_index.value,
-                job.future.done(),
-                job.get_progress(),
+                device.device,
+                self.pending[device.device].qsize(),
             )
-            for job in self.jobs
+            for device in self.devices
         ]
 
-        recent = [
-            (
-                key,
-                None,
-                True,
-                progress,
-            )
-            for key, progress in self.recent
-        ]
-
-        pending.extend(recent)
+        pending.extend(self.finished)
         return pending
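A rough usage sketch against the pool as of this commit; the DeviceParams list and the parts of submit()'s signature not visible in this hunk are assumptions:

def render(context, size=512):
    ...  # any callable that takes a JobContext as its first argument

pool = DevicePoolExecutor(devices)  # devices: List[DeviceParams]
pool.submit("job-1", render, size=512)

pool.status()  # per-device (name, queue depth) pairs plus finished (key, progress) pairs
pool.done("job-1")  # (True, progress) once the key lands in finished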