1
0
Fork 0

begin switching to per-device torch mp workers

This commit is contained in:
Sean Sube 2023-02-25 23:16:32 -06:00
parent e03b637f54
commit e46a1e5fd0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 53 additions and 98 deletions

View File

@ -1,9 +1,11 @@
from collections import Counter from collections import Counter
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures import Future
from logging import getLogger from logging import getLogger
from multiprocessing import Value from multiprocessing import Queue
from torch.multiprocessing import Lock, Process, SimpleQueue, Value
from traceback import format_exception from traceback import format_exception
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from time import sleep
from ..params import DeviceParams from ..params import DeviceParams
from ..utils import run_gc from ..utils import run_gc
@ -13,6 +15,18 @@ logger = getLogger(__name__)
ProgressCallback = Callable[[int, int, Any], None] ProgressCallback = Callable[[int, int, Any], None]
def worker_init(lock: Lock, job_queue: SimpleQueue):
logger.info("checking in from worker")
while True:
if job_queue.empty():
logger.info("no jobs, sleeping")
sleep(5)
else:
job = job_queue.get()
logger.info("got job: %s", job)
class JobContext: class JobContext:
cancel: Value = None cancel: Value = None
device_index: Value = None device_index: Value = None
@ -104,38 +118,31 @@ class Job:
class DevicePoolExecutor: class DevicePoolExecutor:
devices: List[DeviceParams] = None devices: List[DeviceParams] = None
jobs: List[Job] = None finished: List[Tuple[str, int]] = None
next_device: int = 0 pending: Dict[str, "Queue[Job]"] = None
pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None progress: Dict[str, Value] = None
recent: List[Tuple[str, int]] = None workers: Dict[str, Process] = None
def __init__( def __init__(
self, self,
devices: List[DeviceParams], devices: List[DeviceParams],
pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None, finished_limit: int = 10,
recent_limit: int = 10,
): ):
self.devices = devices self.devices = devices
self.jobs = [] self.finished = []
self.next_device = 0 self.finished_limit = finished_limit
self.recent = [] self.lock = Lock()
self.recent_limit = recent_limit self.pending = {}
self.progress = {}
self.workers = {}
device_count = len(devices) # create a pending queue and progress value for each device
if pool is None: for device in devices:
logger.info( name = device.device
"creating thread pool executor for %s devices: %s", job_queue = Queue()
device_count, self.pending[name] = job_queue
[d.device for d in devices], self.progress[name] = Value("I", 0, lock=self.lock)
) self.workers[name] = Process(target=worker_init, args=(self.lock, job_queue))
self.pool = ThreadPoolExecutor(device_count)
else:
logger.info(
"using existing pool for %s devices: %s",
device_count,
[d.device for d in devices],
)
self.pool = pool
def cancel(self, key: str) -> bool: def cancel(self, key: str) -> bool:
""" """
@ -143,31 +150,13 @@ class DevicePoolExecutor:
the future and never execute it. If the job has been started, it the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback. should be cancelled on the next progress callback.
""" """
for job in self.jobs: raise NotImplementedError()
if job.key == key:
if job.future.cancel():
return True
else:
job.set_cancel()
return True
return False
def done(self, key: str) -> Tuple[Optional[bool], int]: def done(self, key: str) -> Tuple[Optional[bool], int]:
for k, progress in self.recent: for k, progress in self.finished:
if key == k: if key == k:
return (True, progress) return (True, progress)
for job in self.jobs:
if job.key == key:
done = job.future.done()
progress = job.get_progress()
if done:
self.prune()
return (done, progress)
logger.warn("checking status for unknown key: %s", key) logger.warn("checking status for unknown key: %s", key)
return (None, 0) return (None, 0)
@ -198,24 +187,14 @@ class DevicePoolExecutor:
return lowest_devices[0] return lowest_devices[0]
def prune(self): def prune(self):
pending_jobs = [job for job in self.jobs if job.future.done()] finished_count = len(self.finished)
logger.debug("pruning %s of %s pending jobs", len(pending_jobs), len(self.jobs)) if finished_count > self.finished_limit:
for job in pending_jobs:
self.recent.append((job.key, job.get_progress()))
try:
self.jobs.remove(job)
except ValueError as e:
logger.warning("error removing pruned job from pending: %s", e)
recent_count = len(self.recent)
if recent_count > self.recent_limit:
logger.debug( logger.debug(
"pruning %s of %s recent jobs", "pruning %s of %s finished jobs",
recent_count - self.recent_limit, finished_count - self.finished_limit,
recent_count, finished_count,
) )
self.recent[:] = self.recent[-self.recent_limit :] self.finished[:] = self.finished[-self.finished_limit:]
def submit( def submit(
self, self,
@ -227,49 +206,25 @@ class DevicePoolExecutor:
**kwargs, **kwargs,
) -> None: ) -> None:
self.prune() self.prune()
device = self.get_next_device(needs_device=needs_device) device_idx = self.get_next_device(needs_device=needs_device)
logger.info( logger.info(
"assigning job %s to device %s: %s", key, device, self.devices[device] "assigning job %s to device %s: %s", key, device_idx, self.devices[device_idx]
) )
context = JobContext(key, self.devices, device_index=device) context = JobContext(key, self.devices, device_index=device_idx)
future = self.pool.submit(fn, context, *args, **kwargs) device = self.devices[device_idx]
job = Job(key, future, context)
self.jobs.append(job)
def job_done(f: Future): queue = self.pending[device.device]
try: queue.put((fn, context, args, kwargs))
f.result()
logger.info("job %s finished successfully", key)
except Exception as err:
logger.warn(
"job %s failed with an error: %s",
key,
format_exception(type(err), err, err.__traceback__),
)
run_gc([self.devices[device]])
future.add_done_callback(job_done)
def status(self) -> List[Tuple[str, int, bool, int]]: def status(self) -> List[Tuple[str, int, bool, int]]:
pending = [ pending = [
( (
job.key, device.device,
job.context.device_index.value, self.pending[device.device].qsize(),
job.future.done(),
job.get_progress(),
) )
for job in self.jobs for device in self.devices
] ]
recent = [ pending.extend(self.finished)
(
key,
None,
True,
progress,
)
for key, progress in self.recent
]
pending.extend(recent)
return pending return pending