1
0
Fork 0

clear worker flags between jobs, attempt to record finished jobs again

This commit is contained in:
Sean Sube 2023-02-26 15:06:40 -06:00
parent d1961afdbc
commit 85118d17c6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 109 additions and 116 deletions

View File

@@ -20,12 +20,12 @@ class WorkerContext:
def __init__(
self,
key: str,
cancel: "Value[bool]",
device: DeviceParams,
pending: "Queue[Any]",
progress: "Value[int]",
logs: "Queue[str]",
finished: "Value[bool]",
cancel: "Value[bool]" = None,
finished: "Value[bool]" = None,
progress: "Value[int]" = None,
logs: "Queue[str]" = None,
pending: "Queue[Any]" = None,
):
self.key = key
self.cancel = cancel
@@ -62,6 +62,15 @@ class WorkerContext:
with self.cancel.get_lock():
self.cancel.value = cancel
def set_finished(self, finished: bool = True) -> None:
with self.finished.get_lock():
self.finished.value = finished
def set_progress(self, progress: int) -> None:
with self.progress.get_lock():
self.progress.value = progress
def clear_flags(self) -> None:
self.set_cancel(False)
self.set_finished(False)
self.set_progress(0)

View File

@@ -3,7 +3,7 @@ from logging import getLogger
from multiprocessing import Queue
from typing import Callable, Dict, List, Optional, Tuple
from torch.multiprocessing import Lock, Process, Value
from torch.multiprocessing import Process, Value
from ..params import DeviceParams
from ..server import ServerContext
@@ -14,67 +14,67 @@ logger = getLogger(__name__)
class DevicePoolExecutor:
context: Dict[str, WorkerContext] = None
devices: List[DeviceParams] = None
finished: Dict[str, "Value[bool]"] = None
pending: Dict[str, "Queue[WorkerContext]"] = None
progress: Dict[str, "Value[int]"] = None
workers: Dict[str, Process] = None
jobs: Dict[str, str] = None
active_job: Dict[str, str] = None
finished: List[Tuple[str, int, bool]] = None
def __init__(
self,
server: ServerContext,
devices: List[DeviceParams],
finished_limit: int = 10,
max_jobs_per_worker: int = 10,
join_timeout: float = 5.0,
):
self.server = server
self.devices = devices
self.finished = {}
self.finished_limit = finished_limit
self.context = {}
self.locks = {}
self.pending = {}
self.progress = {}
self.workers = {}
self.jobs = {} # Dict[Output, Device]
self.job_count = 0
self.max_jobs_per_worker = max_jobs_per_worker
self.join_timeout = join_timeout
# TODO: make this a method
logger.debug("starting log worker")
self.log_queue = Queue()
log_lock = Lock()
self.locks["logger"] = log_lock
self.logger = Process(target=logger_init, args=(log_lock, self.log_queue))
self.logger.start()
self.context = {}
self.pending = {}
self.workers = {}
self.active_job = {}
self.finished_jobs = 0 # TODO: turn this into a Dict per-worker
self.create_logger_worker()
for device in devices:
self.create_device_worker(device)
logger.debug("testing log worker")
self.log_queue.put("testing")
# create a pending queue and progress value for each device
for device in devices:
name = device.device
# TODO: make this a method
lock = Lock()
self.locks[name] = lock
cancel = Value("B", False, lock=lock)
finished = Value("B", False)
self.finished[name] = finished
progress = Value(
"I", 0
) # , lock=lock) # needs its own lock for some reason. TODO: why?
self.progress[name] = progress
pending = Queue()
self.pending[name] = pending
context = WorkerContext(
name, cancel, device, pending, progress, self.log_queue, finished
)
self.context[name] = context
def create_logger_worker(self) -> None:
self.log_queue = Queue()
self.logger = Process(target=logger_init, args=(self.log_queue))
logger.debug("starting worker for device %s", device)
self.workers[name] = Process(
target=worker_init, args=(lock, context, server)
)
self.workers[name].start()
logger.debug("starting log worker")
self.logger.start()
def create_device_worker(self, device: DeviceParams) -> None:
name = device.device
pending = Queue()
self.pending[name] = pending
context = WorkerContext(
name,
device,
cancel=Value("B", False),
finished=Value("B", False),
progress=Value("I", 0),
pending=pending,
logs=self.log_queue,
)
self.context[name] = context
self.workers[name] = Process(target=worker_init, args=(context, self.server))
logger.debug("starting worker for device %s", device)
self.workers[name].start()
def create_prune_worker(self) -> None:
# TODO: create a background thread to prune completed jobs
pass
def cancel(self, key: str) -> bool:
"""
@@ -82,29 +82,34 @@ class DevicePoolExecutor:
the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback.
"""
if key not in self.jobs:
if key not in self.active_job:
logger.warn("attempting to cancel unknown job: %s", key)
return False
device = self.jobs[key]
cancel = self.context[device].cancel
device = self.active_job[key]
context = self.context[device]
logger.info("cancelling job %s on device %s", key, device)
if cancel.get_lock():
cancel.value = True
if context.cancel.get_lock():
context.cancel.value = True
# self.finished.append((key, context.progress.value, context.cancel.value)) maybe?
return True
def done(self, key: str) -> Tuple[Optional[bool], int]:
if key not in self.jobs:
if key not in self.active_job:
logger.warn("checking status for unknown job: %s", key)
return (None, 0)
device = self.jobs[key]
finished = self.finished[device]
progress = self.progress[device]
# TODO: prune here, maybe?
return (finished.value, progress.value)
device = self.active_job[key]
context = self.context[device]
if context.finished.value is True:
self.finished.append((key, context.progress.value, context.cancel.value))
return (context.finished.value, context.progress.value)
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible
@@ -113,9 +118,8 @@ class DevicePoolExecutor:
if self.devices[i].device == needs_device.device:
return i
pending = [self.pending[d.device].qsize() for d in self.devices]
jobs = Counter(range(len(self.devices)))
jobs.update(pending)
jobs.update([self.pending[d.device].qsize() for d in self.devices])
queued = jobs.most_common()
logger.debug("jobs queued by device: %s", queued)
@@ -130,26 +134,16 @@ class DevicePoolExecutor:
for device, worker in self.workers.items():
if worker.is_alive():
logger.info("stopping worker for device %s", device)
worker.join(5)
worker.join(self.join_timeout)
if self.logger.is_alive():
self.logger.join(5)
def prune(self):
finished_count = len(self.finished)
if finished_count > self.finished_limit:
logger.debug(
"pruning %s of %s finished jobs",
finished_count - self.finished_limit,
finished_count,
)
self.finished[:] = self.finished[-self.finished_limit :]
self.logger.join(self.join_timeout)
def recycle(self):
for name, proc in self.workers.items():
if proc.is_alive():
logger.debug("shutting down worker for device %s", name)
proc.join(5)
proc.join(self.join_timeout)
proc.terminate()
else:
logger.warning("worker for device %s has died", name)
@@ -159,15 +153,8 @@ class DevicePoolExecutor:
logger.info("starting new workers")
for name in self.workers.keys():
context = self.context[name]
lock = self.locks[name]
logger.debug("starting worker for device %s", name)
self.workers[name] = Process(
target=worker_init, args=(lock, context, self.server)
)
self.workers[name].start()
for device in self.devices:
self.create_device_worker(device)
def submit(
self,
@@ -178,13 +165,12 @@
needs_device: Optional[DeviceParams] = None,
**kwargs,
) -> None:
self.job_count += 1
logger.debug("pool job count: %s", self.job_count)
if self.job_count > 10:
self.finished_jobs += 1
logger.debug("pool job count: %s", self.finished_jobs)
if self.finished_jobs > self.max_jobs_per_worker:
self.recycle()
self.job_count = 0
self.finished_jobs = 0
self.prune()
device_idx = self.get_next_device(needs_device=needs_device)
logger.info(
"assigning job %s to device %s: %s",
@@ -197,17 +183,19 @@ class DevicePoolExecutor:
queue = self.pending[device.device]
queue.put((fn, args, kwargs))
self.jobs[key] = device.device
self.active_job[key] = device.device
def status(self) -> List[Tuple[str, int, bool, int]]:
pending = [
(
device.device,
self.pending[device.device].qsize(),
self.progress[device.device].value,
self.workers[device.device].is_alive(),
name,
self.workers[name].is_alive(),
context.pending.qsize(),
context.cancel.value,
context.finished.value,
context.progress.value,
)
for device in self.devices
for name, context in self.context.items()
]
pending.extend(self.finished)
return pending

View File

@@ -2,7 +2,7 @@ from logging import getLogger
from traceback import format_exception
from setproctitle import setproctitle
from torch.multiprocessing import Lock, Queue
from torch.multiprocessing import Queue
from ..onnx.torch_before_ort import get_available_providers
from ..server import ServerContext, apply_patches
@@ -11,12 +11,11 @@ from .context import WorkerContext
logger = getLogger(__name__)
def logger_init(lock: Lock, logs: Queue):
with lock:
logger.info("checking in from logger, %s", lock)
def logger_init(logs: Queue):
setproctitle("onnx-web logger")
logger.info("checking in from logger, %s")
while True:
job = logs.get()
with open("worker.log", "w") as f:
@@ -24,31 +23,28 @@ def logger_init(lock: Lock, logs: Queue):
f.write(str(job) + "\n\n")
def worker_init(lock: Lock, context: WorkerContext, server: ServerContext):
with lock:
logger.info("checking in from worker, %s, %s", lock, get_available_providers())
def worker_init(context: WorkerContext, server: ServerContext):
apply_patches(server)
setproctitle("onnx-web worker: %s" % (context.device.device))
logger.info("checking in from worker, %s, %s", get_available_providers())
while True:
job = context.pending.get()
logger.info("got job: %s", job)
try:
fn, args, kwargs = job
name = args[3][0]
logger.info("starting job: %s", name)
with context.finished.get_lock():
context.finished.value = False
with context.progress.get_lock():
context.progress.value = 0
context.clear_flags()
fn(context, *args, **kwargs)
logger.info("finished job: %s", name)
with context.finished.get_lock():
context.finished.value = True
logger.info("job succeeded: %s", name)
except Exception as e:
logger.error(format_exception(type(e), e, e.__traceback__))
logger.error(
"error while running job: %s",
format_exception(type(e), e, e.__traceback__),
)
finally:
context.set_finished()
logger.info("finished job: %s", name)