lock per worker, torch before ORT
This commit is contained in:
parent
d765a6f01b
commit
e1d0ad54b7
|
@ -2,6 +2,7 @@ from enum import IntEnum
|
|||
from logging import getLogger
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from onnxruntime import GraphOptimizationLevel, SessionOptions
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
|
|
@ -26,7 +26,8 @@ class DevicePoolExecutor:
|
|||
self.devices = devices
|
||||
self.finished = []
|
||||
self.finished_limit = finished_limit
|
||||
self.lock = Lock()
|
||||
self.context = {}
|
||||
self.locks = {}
|
||||
self.pending = {}
|
||||
self.progress = {}
|
||||
self.workers = {}
|
||||
|
@ -42,15 +43,18 @@ class DevicePoolExecutor:
|
|||
# create a pending queue and progress value for each device
|
||||
for device in devices:
|
||||
name = device.device
|
||||
cancel = Value("B", False, lock=self.lock)
|
||||
progress = Value("I", 0, lock=self.lock)
|
||||
lock = Lock()
|
||||
self.locks[name] = lock
|
||||
cancel = Value("B", False, lock=lock)
|
||||
progress = Value("I", 0, lock=lock)
|
||||
self.progress[name] = progress
|
||||
pending = Queue()
|
||||
context = WorkerContext(name, cancel, device, pending, progress)
|
||||
self.pending[name] = pending
|
||||
self.progress[name] = pending
|
||||
context = WorkerContext(name, cancel, device, pending, progress)
|
||||
self.context[name] = context
|
||||
|
||||
logger.debug("starting worker for device %s", device)
|
||||
self.workers[name] = Process(target=worker_init, args=(self.lock, context))
|
||||
self.workers[name] = Process(target=worker_init, args=(lock, context))
|
||||
self.workers[name].start()
|
||||
|
||||
def cancel(self, key: str) -> bool:
|
||||
|
@ -135,6 +139,7 @@ class DevicePoolExecutor:
|
|||
(
|
||||
device.device,
|
||||
self.pending[device.device].qsize(),
|
||||
self.progress[device.device].value,
|
||||
self.workers[device.device].is_alive(),
|
||||
)
|
||||
for device in self.devices
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
from logging import getLogger
|
||||
import torch # has to come before ORT
|
||||
from onnxruntime import get_available_providers
|
||||
from torch.multiprocessing import Lock, Queue
|
||||
from traceback import print_exception
|
||||
from traceback import format_exception
|
||||
|
||||
from .context import WorkerContext
|
||||
|
||||
|
@ -30,5 +31,5 @@ def worker_init(lock: Lock, context: WorkerContext):
|
|||
fn(context, *args, **kwargs)
|
||||
logger.info("finished job")
|
||||
except Exception as e:
|
||||
print_exception(type(e), e, e.__traceback__)
|
||||
logger.error(format_exception(type(e), e, e.__traceback__))
|
||||
|
||||
|
|
Loading…
Reference in New Issue