clear worker flags between jobs, attempt to record finished jobs again
This commit is contained in:
parent
d1961afdbc
commit
85118d17c6
|
@ -20,12 +20,12 @@ class WorkerContext:
|
|||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
cancel: "Value[bool]",
|
||||
device: DeviceParams,
|
||||
pending: "Queue[Any]",
|
||||
progress: "Value[int]",
|
||||
logs: "Queue[str]",
|
||||
finished: "Value[bool]",
|
||||
cancel: "Value[bool]" = None,
|
||||
finished: "Value[bool]" = None,
|
||||
progress: "Value[int]" = None,
|
||||
logs: "Queue[str]" = None,
|
||||
pending: "Queue[Any]" = None,
|
||||
):
|
||||
self.key = key
|
||||
self.cancel = cancel
|
||||
|
@ -62,6 +62,15 @@ class WorkerContext:
|
|||
with self.cancel.get_lock():
|
||||
self.cancel.value = cancel
|
||||
|
||||
def set_finished(self, finished: bool = True) -> None:
    """Write ``finished`` into this context's ``finished`` value.

    The assignment is made while holding the lock returned by the value's
    own ``get_lock()``.

    Args:
        finished: the new flag value; defaults to ``True``.
    """
    flag = self.finished
    with flag.get_lock():
        flag.value = finished
|
||||
|
||||
def set_progress(self, progress: int) -> None:
    """Write ``progress`` into this context's ``progress`` value.

    The assignment is made while holding the lock returned by the value's
    own ``get_lock()``.

    Args:
        progress: the new progress counter.
    """
    counter = self.progress
    with counter.get_lock():
        counter.value = progress
|
||||
|
||||
def clear_flags(self) -> None:
    """Reset the cancel, finished, and progress flags to their idle values.

    Each flag goes through its own setter, in this order: cancel is set to
    ``False``, finished to ``False``, and progress to ``0``.
    """
    self.set_cancel(False)
    self.set_finished(False)
    self.set_progress(0)
|
||||
|
|
|
@ -3,7 +3,7 @@ from logging import getLogger
|
|||
from multiprocessing import Queue
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch.multiprocessing import Lock, Process, Value
|
||||
from torch.multiprocessing import Process, Value
|
||||
|
||||
from ..params import DeviceParams
|
||||
from ..server import ServerContext
|
||||
|
@ -14,67 +14,67 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
class DevicePoolExecutor:
|
||||
context: Dict[str, WorkerContext] = None
|
||||
devices: List[DeviceParams] = None
|
||||
finished: Dict[str, "Value[bool]"] = None
|
||||
pending: Dict[str, "Queue[WorkerContext]"] = None
|
||||
progress: Dict[str, "Value[int]"] = None
|
||||
workers: Dict[str, Process] = None
|
||||
jobs: Dict[str, str] = None
|
||||
active_job: Dict[str, str] = None
|
||||
finished: List[Tuple[str, int, bool]] = None
|
||||
|
||||
def __init__(
    self,
    server: ServerContext,
    devices: List[DeviceParams],
    max_jobs_per_worker: int = 10,
    join_timeout: float = 5.0,
):
    """Set up the pool's bookkeeping and create its worker processes.

    Calls ``create_logger_worker()`` once and ``create_device_worker()``
    for every entry in ``devices``, then puts a test message on the log
    queue.

    Args:
        server: server context stored on the pool for its workers.
        devices: devices that each get their own worker.
        max_jobs_per_worker: threshold that ``submit`` compares the job
            counter against before recycling the workers.
        join_timeout: seconds passed to ``join()`` when stopping processes.
    """
    self.server = server
    self.devices = devices
    self.max_jobs_per_worker = max_jobs_per_worker
    self.join_timeout = join_timeout

    self.context = {}
    self.pending = {}
    self.workers = {}
    self.active_job = {}
    # The class-level default for `finished` is None; done() appends to this
    # list and status() extends its result with it, so it has to be a real
    # list on every instance.
    self.finished = []
    self.finished_jobs = 0  # TODO: turn this into a Dict per-worker

    self.create_logger_worker()
    for device in devices:
        self.create_device_worker(device)

    logger.debug("testing log worker")
    self.log_queue.put("testing")
|
||||
def create_logger_worker(self) -> None:
    """Create the log queue and start the logger process.

    The queue is stored as ``self.log_queue`` and the process, which runs
    ``logger_init`` with that queue, as ``self.logger``.
    """
    self.log_queue = Queue()
    # `args` must be a tuple of positional arguments; without the trailing
    # comma, `(self.log_queue)` is just the queue object itself.
    self.logger = Process(target=logger_init, args=(self.log_queue,))

    logger.debug("starting log worker")
    self.logger.start()
|
||||
|
||||
def create_device_worker(self, device: DeviceParams) -> None:
    """Create and start the worker process for one device.

    A fresh pending queue, ``WorkerContext`` and ``Process`` are built and
    registered in ``self.pending``, ``self.context`` and ``self.workers``
    under ``device.device``. The context gets new cancel, finished and
    progress values and shares the pool's log queue.

    Args:
        device: the device this worker is created for.
    """
    name = device.device

    job_queue = Queue()
    self.pending[name] = job_queue

    worker_context = WorkerContext(
        name,
        device,
        cancel=Value("B", False),
        finished=Value("B", False),
        progress=Value("I", 0),
        pending=job_queue,
        logs=self.log_queue,
    )
    self.context[name] = worker_context

    worker = Process(target=worker_init, args=(worker_context, self.server))
    self.workers[name] = worker

    logger.debug("starting worker for device %s", device)
    worker.start()
|
||||
|
||||
def create_prune_worker(self) -> None:
    """Placeholder for a future prune worker; currently does nothing."""
    # TODO: create a background thread to prune completed jobs
    return None
|
||||
|
||||
def cancel(self, key: str) -> bool:
|
||||
"""
|
||||
|
@ -82,29 +82,34 @@ class DevicePoolExecutor:
|
|||
the future and never execute it. If the job has been started, it
|
||||
should be cancelled on the next progress callback.
|
||||
"""
|
||||
if key not in self.jobs:
|
||||
if key not in self.active_job:
|
||||
logger.warn("attempting to cancel unknown job: %s", key)
|
||||
return False
|
||||
|
||||
device = self.jobs[key]
|
||||
cancel = self.context[device].cancel
|
||||
device = self.active_job[key]
|
||||
context = self.context[device]
|
||||
logger.info("cancelling job %s on device %s", key, device)
|
||||
|
||||
if cancel.get_lock():
|
||||
cancel.value = True
|
||||
if context.cancel.get_lock():
|
||||
context.cancel.value = True
|
||||
|
||||
# self.finished.append((key, context.progress.value, context.cancel.value)) maybe?
|
||||
return True
|
||||
|
||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
    """Report whether the job ``key`` has finished, and its progress.

    When the job's worker reports finished, a ``(key, progress, cancelled)``
    entry is recorded in ``self.finished`` once per key.

    Args:
        key: the job key, as stored in ``self.active_job``.

    Returns:
        ``(None, 0)`` when the key is unknown, otherwise
        ``(finished, progress)`` read from the context of the device the
        job was assigned to.
    """
    if key not in self.active_job:
        logger.warning("checking status for unknown job: %s", key)
        return (None, 0)

    # TODO: prune here, maybe?

    device = self.active_job[key]
    context = self.context[device]

    # The flag is created as Value("B", ...), whose .value is an int (0/1),
    # so an identity comparison against True would never match.
    finished = bool(context.finished.value)
    progress = context.progress.value

    if finished:
        # `finished` may still be the class-level None default.
        if self.finished is None:
            self.finished = []

        # This method is polled, so only record each job once.
        if not any(entry[0] == key for entry in self.finished):
            self.finished.append((key, progress, bool(context.cancel.value)))

    return (finished, progress)
|
||||
|
||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||
# respect overrides if possible
|
||||
|
@ -113,9 +118,8 @@ class DevicePoolExecutor:
|
|||
if self.devices[i].device == needs_device.device:
|
||||
return i
|
||||
|
||||
pending = [self.pending[d.device].qsize() for d in self.devices]
|
||||
jobs = Counter(range(len(self.devices)))
|
||||
jobs.update(pending)
|
||||
jobs.update([self.pending[d.device].qsize() for d in self.devices])
|
||||
|
||||
queued = jobs.most_common()
|
||||
logger.debug("jobs queued by device: %s", queued)
|
||||
|
@ -130,26 +134,16 @@ class DevicePoolExecutor:
|
|||
for device, worker in self.workers.items():
|
||||
if worker.is_alive():
|
||||
logger.info("stopping worker for device %s", device)
|
||||
worker.join(5)
|
||||
worker.join(self.join_timeout)
|
||||
|
||||
if self.logger.is_alive():
|
||||
self.logger.join(5)
|
||||
|
||||
def prune(self):
|
||||
finished_count = len(self.finished)
|
||||
if finished_count > self.finished_limit:
|
||||
logger.debug(
|
||||
"pruning %s of %s finished jobs",
|
||||
finished_count - self.finished_limit,
|
||||
finished_count,
|
||||
)
|
||||
self.finished[:] = self.finished[-self.finished_limit :]
|
||||
self.logger.join(self.join_timeout)
|
||||
|
||||
def recycle(self):
|
||||
for name, proc in self.workers.items():
|
||||
if proc.is_alive():
|
||||
logger.debug("shutting down worker for device %s", name)
|
||||
proc.join(5)
|
||||
proc.join(self.join_timeout)
|
||||
proc.terminate()
|
||||
else:
|
||||
logger.warning("worker for device %s has died", name)
|
||||
|
@ -159,15 +153,8 @@ class DevicePoolExecutor:
|
|||
|
||||
logger.info("starting new workers")
|
||||
|
||||
for name in self.workers.keys():
|
||||
context = self.context[name]
|
||||
lock = self.locks[name]
|
||||
|
||||
logger.debug("starting worker for device %s", name)
|
||||
self.workers[name] = Process(
|
||||
target=worker_init, args=(lock, context, self.server)
|
||||
)
|
||||
self.workers[name].start()
|
||||
for device in self.devices:
|
||||
self.create_device_worker(device)
|
||||
|
||||
def submit(
|
||||
self,
|
||||
|
@ -178,13 +165,12 @@ class DevicePoolExecutor:
|
|||
needs_device: Optional[DeviceParams] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.job_count += 1
|
||||
logger.debug("pool job count: %s", self.job_count)
|
||||
if self.job_count > 10:
|
||||
self.finished_jobs += 1
|
||||
logger.debug("pool job count: %s", self.finished_jobs)
|
||||
if self.finished_jobs > self.max_jobs_per_worker:
|
||||
self.recycle()
|
||||
self.job_count = 0
|
||||
self.finished_jobs = 0
|
||||
|
||||
self.prune()
|
||||
device_idx = self.get_next_device(needs_device=needs_device)
|
||||
logger.info(
|
||||
"assigning job %s to device %s: %s",
|
||||
|
@ -197,17 +183,19 @@ class DevicePoolExecutor:
|
|||
queue = self.pending[device.device]
|
||||
queue.put((fn, args, kwargs))
|
||||
|
||||
self.jobs[key] = device.device
|
||||
self.active_job[key] = device.device
|
||||
|
||||
def status(self) -> List[Tuple]:
    """List the state of every worker, followed by the finished jobs.

    Returns:
        One ``(name, alive, queued, cancel, finished, progress)`` tuple per
        entry in ``self.context``, followed by the entries of
        ``self.finished`` (if any have been recorded).
    """
    rows = [
        (
            name,
            self.workers[name].is_alive(),
            context.pending.qsize(),
            context.cancel.value,
            context.finished.value,
            context.progress.value,
        )
        for name, context in self.context.items()
    ]
    # `finished` may still be the class-level None default, which cannot be
    # passed to extend().
    rows.extend(self.finished or [])
    return rows
|
||||
|
|
|
@ -2,7 +2,7 @@ from logging import getLogger
|
|||
from traceback import format_exception
|
||||
|
||||
from setproctitle import setproctitle
|
||||
from torch.multiprocessing import Lock, Queue
|
||||
from torch.multiprocessing import Queue
|
||||
|
||||
from ..onnx.torch_before_ort import get_available_providers
|
||||
from ..server import ServerContext, apply_patches
|
||||
|
@ -11,12 +11,11 @@ from .context import WorkerContext
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def logger_init(lock: Lock, logs: Queue):
|
||||
with lock:
|
||||
logger.info("checking in from logger, %s", lock)
|
||||
|
||||
def logger_init(logs: Queue):
|
||||
setproctitle("onnx-web logger")
|
||||
|
||||
logger.info("checking in from logger, %s")
|
||||
|
||||
while True:
|
||||
job = logs.get()
|
||||
with open("worker.log", "w") as f:
|
||||
|
@ -24,31 +23,28 @@ def logger_init(lock: Lock, logs: Queue):
|
|||
f.write(str(job) + "\n\n")
|
||||
|
||||
|
||||
def worker_init(context: WorkerContext, server: ServerContext):
    """Entry point of a device worker process: run queued jobs forever.

    Applies the server patches, sets the process title, then loops: take a
    ``(fn, args, kwargs)`` job from ``context.pending``, clear the context
    flags, call ``fn(context, *args, **kwargs)``, and mark the context
    finished whether the job succeeded or raised.

    Args:
        context: this worker's context (pending queue and flags).
        server: server context passed to ``apply_patches``.
    """
    apply_patches(server)
    setproctitle("onnx-web worker: %s" % (context.device.device))

    # one placeholder per argument; a mismatch makes logging report a
    # formatting error instead of the message
    logger.info("checking in from worker, %s", get_available_providers())

    while True:
        job = context.pending.get()
        logger.info("got job: %s", job)

        # Bound before the try block so the finally clause can log it even
        # when the job tuple cannot be unpacked.
        name = None
        try:
            fn, args, kwargs = job
            # NOTE(review): assumes args[3][0] identifies the job — confirm
            # against the callers of submit()
            name = args[3][0]

            logger.info("starting job: %s", name)
            context.clear_flags()
            fn(context, *args, **kwargs)
            logger.info("job succeeded: %s", name)
        except Exception as e:
            # keep the worker alive: log the failure and move to the next job
            logger.error(
                "error while running job: %s",
                format_exception(type(e), e, e.__traceback__),
            )
        finally:
            context.set_finished()
            logger.info("finished job: %s", name)
|
||||
|
|
Loading…
Reference in New Issue