1
0
Fork 0

lock per worker, torch before ORT

This commit is contained in:
Sean Sube 2023-02-26 12:24:51 -06:00
parent d765a6f01b
commit e1d0ad54b7
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 15 additions and 8 deletions

View File

@ -2,6 +2,7 @@ from enum import IntEnum
from logging import getLogger
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import torch
from onnxruntime import GraphOptimizationLevel, SessionOptions
logger = getLogger(__name__)

View File

@ -26,7 +26,8 @@ class DevicePoolExecutor:
self.devices = devices
self.finished = []
self.finished_limit = finished_limit
self.lock = Lock()
self.context = {}
self.locks = {}
self.pending = {}
self.progress = {}
self.workers = {}
@ -42,15 +43,18 @@ class DevicePoolExecutor:
# create a pending queue and progress value for each device
for device in devices:
name = device.device
cancel = Value("B", False, lock=self.lock)
progress = Value("I", 0, lock=self.lock)
lock = Lock()
self.locks[name] = lock
cancel = Value("B", False, lock=lock)
progress = Value("I", 0, lock=lock)
self.progress[name] = progress
pending = Queue()
context = WorkerContext(name, cancel, device, pending, progress)
self.pending[name] = pending
self.progress[name] = pending
context = WorkerContext(name, cancel, device, pending, progress)
self.context[name] = context
logger.debug("starting worker for device %s", device)
self.workers[name] = Process(target=worker_init, args=(self.lock, context))
self.workers[name] = Process(target=worker_init, args=(lock, context))
self.workers[name].start()
def cancel(self, key: str) -> bool:
@ -135,6 +139,7 @@ class DevicePoolExecutor:
(
device.device,
self.pending[device.device].qsize(),
self.progress[device.device].value,
self.workers[device.device].is_alive(),
)
for device in self.devices

View File

@ -1,7 +1,8 @@
from logging import getLogger
import torch # has to come before ORT
from onnxruntime import get_available_providers
from torch.multiprocessing import Lock, Queue
from traceback import print_exception
from traceback import format_exception
from .context import WorkerContext
@ -30,5 +31,5 @@ def worker_init(lock: Lock, context: WorkerContext):
fn(context, *args, **kwargs)
logger.info("finished job")
except Exception as e:
print_exception(type(e), e, e.__traceback__)
logger.error(format_exception(type(e), e, e.__traceback__))