1
0
Fork 0

clear worker flags between jobs, attempt to record finished jobs again

This commit is contained in:
Sean Sube 2023-02-26 15:06:40 -06:00
parent d1961afdbc
commit 85118d17c6
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 109 additions and 116 deletions

View File

@ -20,12 +20,12 @@ class WorkerContext:
def __init__( def __init__(
self, self,
key: str, key: str,
cancel: "Value[bool]",
device: DeviceParams, device: DeviceParams,
pending: "Queue[Any]", cancel: "Value[bool]" = None,
progress: "Value[int]", finished: "Value[bool]" = None,
logs: "Queue[str]", progress: "Value[int]" = None,
finished: "Value[bool]", logs: "Queue[str]" = None,
pending: "Queue[Any]" = None,
): ):
self.key = key self.key = key
self.cancel = cancel self.cancel = cancel
@ -62,6 +62,15 @@ class WorkerContext:
with self.cancel.get_lock(): with self.cancel.get_lock():
self.cancel.value = cancel self.cancel.value = cancel
def set_finished(self, finished: bool = True) -> None:
with self.finished.get_lock():
self.finished.value = finished
def set_progress(self, progress: int) -> None: def set_progress(self, progress: int) -> None:
with self.progress.get_lock(): with self.progress.get_lock():
self.progress.value = progress self.progress.value = progress
def clear_flags(self) -> None:
self.set_cancel(False)
self.set_finished(False)
self.set_progress(0)

View File

@ -3,7 +3,7 @@ from logging import getLogger
from multiprocessing import Queue from multiprocessing import Queue
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional, Tuple
from torch.multiprocessing import Lock, Process, Value from torch.multiprocessing import Process, Value
from ..params import DeviceParams from ..params import DeviceParams
from ..server import ServerContext from ..server import ServerContext
@ -14,97 +14,102 @@ logger = getLogger(__name__)
class DevicePoolExecutor: class DevicePoolExecutor:
context: Dict[str, WorkerContext] = None
devices: List[DeviceParams] = None devices: List[DeviceParams] = None
finished: Dict[str, "Value[bool]"] = None
pending: Dict[str, "Queue[WorkerContext]"] = None pending: Dict[str, "Queue[WorkerContext]"] = None
progress: Dict[str, "Value[int]"] = None
workers: Dict[str, Process] = None workers: Dict[str, Process] = None
jobs: Dict[str, str] = None active_job: Dict[str, str] = None
finished: List[Tuple[str, int, bool]] = None
def __init__( def __init__(
self, self,
server: ServerContext, server: ServerContext,
devices: List[DeviceParams], devices: List[DeviceParams],
finished_limit: int = 10, max_jobs_per_worker: int = 10,
join_timeout: float = 5.0,
): ):
self.server = server self.server = server
self.devices = devices self.devices = devices
self.finished = {} self.max_jobs_per_worker = max_jobs_per_worker
self.finished_limit = finished_limit self.join_timeout = join_timeout
self.context = {}
self.locks = {}
self.pending = {}
self.progress = {}
self.workers = {}
self.jobs = {} # Dict[Output, Device]
self.job_count = 0
# TODO: make this a method self.context = {}
logger.debug("starting log worker") self.pending = {}
self.log_queue = Queue() self.workers = {}
log_lock = Lock() self.active_job = {}
self.locks["logger"] = log_lock self.finished_jobs = 0 # TODO: turn this into a Dict per-worker
self.logger = Process(target=logger_init, args=(log_lock, self.log_queue))
self.logger.start() self.create_logger_worker()
for device in devices:
self.create_device_worker(device)
logger.debug("testing log worker") logger.debug("testing log worker")
self.log_queue.put("testing") self.log_queue.put("testing")
# create a pending queue and progress value for each device def create_logger_worker(self) -> None:
for device in devices: self.log_queue = Queue()
self.logger = Process(target=logger_init, args=(self.log_queue))
logger.debug("starting log worker")
self.logger.start()
def create_device_worker(self, device: DeviceParams) -> None:
name = device.device name = device.device
# TODO: make this a method
lock = Lock()
self.locks[name] = lock
cancel = Value("B", False, lock=lock)
finished = Value("B", False)
self.finished[name] = finished
progress = Value(
"I", 0
) # , lock=lock) # needs its own lock for some reason. TODO: why?
self.progress[name] = progress
pending = Queue() pending = Queue()
self.pending[name] = pending self.pending[name] = pending
context = WorkerContext( context = WorkerContext(
name, cancel, device, pending, progress, self.log_queue, finished name,
device,
cancel=Value("B", False),
finished=Value("B", False),
progress=Value("I", 0),
pending=pending,
logs=self.log_queue,
) )
self.context[name] = context self.context[name] = context
self.workers[name] = Process(target=worker_init, args=(context, self.server))
logger.debug("starting worker for device %s", device) logger.debug("starting worker for device %s", device)
self.workers[name] = Process(
target=worker_init, args=(lock, context, server)
)
self.workers[name].start() self.workers[name].start()
def create_prune_worker(self) -> None:
# TODO: create a background thread to prune completed jobs
pass
def cancel(self, key: str) -> bool: def cancel(self, key: str) -> bool:
""" """
Cancel a job. If the job has not been started, this will cancel Cancel a job. If the job has not been started, this will cancel
the future and never execute it. If the job has been started, it the future and never execute it. If the job has been started, it
should be cancelled on the next progress callback. should be cancelled on the next progress callback.
""" """
if key not in self.jobs: if key not in self.active_job:
logger.warn("attempting to cancel unknown job: %s", key) logger.warn("attempting to cancel unknown job: %s", key)
return False return False
device = self.jobs[key] device = self.active_job[key]
cancel = self.context[device].cancel context = self.context[device]
logger.info("cancelling job %s on device %s", key, device) logger.info("cancelling job %s on device %s", key, device)
if cancel.get_lock(): if context.cancel.get_lock():
cancel.value = True context.cancel.value = True
# self.finished.append((key, context.progress.value, context.cancel.value)) maybe?
return True return True
def done(self, key: str) -> Tuple[Optional[bool], int]: def done(self, key: str) -> Tuple[Optional[bool], int]:
if key not in self.jobs: if key not in self.active_job:
logger.warn("checking status for unknown job: %s", key) logger.warn("checking status for unknown job: %s", key)
return (None, 0) return (None, 0)
device = self.jobs[key] # TODO: prune here, maybe?
finished = self.finished[device]
progress = self.progress[device]
return (finished.value, progress.value) device = self.active_job[key]
context = self.context[device]
if context.finished.value is True:
self.finished.append((key, context.progress.value, context.cancel.value))
return (context.finished.value, context.progress.value)
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible # respect overrides if possible
@ -113,9 +118,8 @@ class DevicePoolExecutor:
if self.devices[i].device == needs_device.device: if self.devices[i].device == needs_device.device:
return i return i
pending = [self.pending[d.device].qsize() for d in self.devices]
jobs = Counter(range(len(self.devices))) jobs = Counter(range(len(self.devices)))
jobs.update(pending) jobs.update([self.pending[d.device].qsize() for d in self.devices])
queued = jobs.most_common() queued = jobs.most_common()
logger.debug("jobs queued by device: %s", queued) logger.debug("jobs queued by device: %s", queued)
@ -130,26 +134,16 @@ class DevicePoolExecutor:
for device, worker in self.workers.items(): for device, worker in self.workers.items():
if worker.is_alive(): if worker.is_alive():
logger.info("stopping worker for device %s", device) logger.info("stopping worker for device %s", device)
worker.join(5) worker.join(self.join_timeout)
if self.logger.is_alive(): if self.logger.is_alive():
self.logger.join(5) self.logger.join(self.join_timeout)
def prune(self):
finished_count = len(self.finished)
if finished_count > self.finished_limit:
logger.debug(
"pruning %s of %s finished jobs",
finished_count - self.finished_limit,
finished_count,
)
self.finished[:] = self.finished[-self.finished_limit :]
def recycle(self): def recycle(self):
for name, proc in self.workers.items(): for name, proc in self.workers.items():
if proc.is_alive(): if proc.is_alive():
logger.debug("shutting down worker for device %s", name) logger.debug("shutting down worker for device %s", name)
proc.join(5) proc.join(self.join_timeout)
proc.terminate() proc.terminate()
else: else:
logger.warning("worker for device %s has died", name) logger.warning("worker for device %s has died", name)
@ -159,15 +153,8 @@ class DevicePoolExecutor:
logger.info("starting new workers") logger.info("starting new workers")
for name in self.workers.keys(): for device in self.devices:
context = self.context[name] self.create_device_worker(device)
lock = self.locks[name]
logger.debug("starting worker for device %s", name)
self.workers[name] = Process(
target=worker_init, args=(lock, context, self.server)
)
self.workers[name].start()
def submit( def submit(
self, self,
@ -178,13 +165,12 @@ class DevicePoolExecutor:
needs_device: Optional[DeviceParams] = None, needs_device: Optional[DeviceParams] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
self.job_count += 1 self.finished_jobs += 1
logger.debug("pool job count: %s", self.job_count) logger.debug("pool job count: %s", self.finished_jobs)
if self.job_count > 10: if self.finished_jobs > self.max_jobs_per_worker:
self.recycle() self.recycle()
self.job_count = 0 self.finished_jobs = 0
self.prune()
device_idx = self.get_next_device(needs_device=needs_device) device_idx = self.get_next_device(needs_device=needs_device)
logger.info( logger.info(
"assigning job %s to device %s: %s", "assigning job %s to device %s: %s",
@ -197,17 +183,19 @@ class DevicePoolExecutor:
queue = self.pending[device.device] queue = self.pending[device.device]
queue.put((fn, args, kwargs)) queue.put((fn, args, kwargs))
self.jobs[key] = device.device self.active_job[key] = device.device
def status(self) -> List[Tuple[str, int, bool, int]]: def status(self) -> List[Tuple[str, int, bool, int]]:
pending = [ pending = [
( (
device.device, name,
self.pending[device.device].qsize(), self.workers[name].is_alive(),
self.progress[device.device].value, context.pending.qsize(),
self.workers[device.device].is_alive(), context.cancel.value,
context.finished.value,
context.progress.value,
) )
for device in self.devices for name, context in self.context.items()
] ]
pending.extend(self.finished) pending.extend(self.finished)
return pending return pending

View File

@ -2,7 +2,7 @@ from logging import getLogger
from traceback import format_exception from traceback import format_exception
from setproctitle import setproctitle from setproctitle import setproctitle
from torch.multiprocessing import Lock, Queue from torch.multiprocessing import Queue
from ..onnx.torch_before_ort import get_available_providers from ..onnx.torch_before_ort import get_available_providers
from ..server import ServerContext, apply_patches from ..server import ServerContext, apply_patches
@ -11,12 +11,11 @@ from .context import WorkerContext
logger = getLogger(__name__) logger = getLogger(__name__)
def logger_init(lock: Lock, logs: Queue): def logger_init(logs: Queue):
with lock:
logger.info("checking in from logger, %s", lock)
setproctitle("onnx-web logger") setproctitle("onnx-web logger")
logger.info("checking in from logger, %s")
while True: while True:
job = logs.get() job = logs.get()
with open("worker.log", "w") as f: with open("worker.log", "w") as f:
@ -24,31 +23,28 @@ def logger_init(lock: Lock, logs: Queue):
f.write(str(job) + "\n\n") f.write(str(job) + "\n\n")
def worker_init(lock: Lock, context: WorkerContext, server: ServerContext): def worker_init(context: WorkerContext, server: ServerContext):
with lock:
logger.info("checking in from worker, %s, %s", lock, get_available_providers())
apply_patches(server) apply_patches(server)
setproctitle("onnx-web worker: %s" % (context.device.device)) setproctitle("onnx-web worker: %s" % (context.device.device))
logger.info("checking in from worker, %s, %s", get_available_providers())
while True: while True:
job = context.pending.get() job = context.pending.get()
logger.info("got job: %s", job) logger.info("got job: %s", job)
try: try:
fn, args, kwargs = job fn, args, kwargs = job
name = args[3][0] name = args[3][0]
logger.info("starting job: %s", name) logger.info("starting job: %s", name)
with context.finished.get_lock(): context.clear_flags()
context.finished.value = False
with context.progress.get_lock():
context.progress.value = 0
fn(context, *args, **kwargs) fn(context, *args, **kwargs)
logger.info("finished job: %s", name) logger.info("job succeeded: %s", name)
with context.finished.get_lock():
context.finished.value = True
except Exception as e: except Exception as e:
logger.error(format_exception(type(e), e, e.__traceback__)) logger.error(
"error while running job: %s",
format_exception(type(e), e, e.__traceback__),
)
finally:
context.set_finished()
logger.info("finished job: %s", name)