1
0
Fork 0

use progress queue

This commit is contained in:
Sean Sube 2023-02-26 20:37:22 -06:00
parent 401ee20526
commit a37d1a4550
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 30 additions and 35 deletions

View File

@ -22,11 +22,10 @@ class WorkerContext:
key: str, key: str,
device: DeviceParams, device: DeviceParams,
cancel: "Value[bool]" = None, cancel: "Value[bool]" = None,
progress: "Value[int]" = None,
finished: "Queue[str]" = None,
logs: "Queue[str]" = None, logs: "Queue[str]" = None,
pending: "Queue[Any]" = None, pending: "Queue[Any]" = None,
started: "Queue[Tuple[str, str]]" = None, progress: "Queue[Tuple[str, int]]" = None,
finished: "Queue[str]" = None,
): ):
self.key = key self.key = key
self.device = device self.device = device
@ -35,7 +34,6 @@ class WorkerContext:
self.finished = finished self.finished = finished
self.logs = logs self.logs = logs
self.pending = pending self.pending = pending
self.started = started
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
return self.cancel.value return self.cancel.value
@ -65,14 +63,10 @@ class WorkerContext:
self.cancel.value = cancel self.cancel.value = cancel
def set_progress(self, progress: int) -> None: def set_progress(self, progress: int) -> None:
with self.progress.get_lock(): self.progress.put((self.key, self.device.device, progress))
self.progress.value = progress
def put_finished(self, job: str) -> None: def set_finished(self) -> None:
self.finished.put((job, self.device.device)) self.finished.put((self.key, self.device.device))
def put_started(self, job: str) -> None:
self.started.put((job, self.device.device))
def clear_flags(self) -> None: def clear_flags(self) -> None:
self.set_cancel(False) self.set_cancel(False)

View File

@ -18,7 +18,7 @@ class DevicePoolExecutor:
devices: List[DeviceParams] = None devices: List[DeviceParams] = None
pending: Dict[str, "Queue[WorkerContext]"] = None pending: Dict[str, "Queue[WorkerContext]"] = None
workers: Dict[str, Process] = None workers: Dict[str, Process] = None
active_jobs: Dict[str, str] = None active_jobs: Dict[str, Tuple[str, int]] = None
finished_jobs: List[Tuple[str, int, bool]] = None finished_jobs: List[Tuple[str, int, bool]] = None
def __init__( def __init__(
@ -40,7 +40,7 @@ class DevicePoolExecutor:
self.finished_jobs = [] self.finished_jobs = []
self.total_jobs = 0 # TODO: turn this into a Dict per-worker self.total_jobs = 0 # TODO: turn this into a Dict per-worker
self.started = Queue() self.progress = Queue()
self.finished = Queue() self.finished = Queue()
self.create_logger_worker() self.create_logger_worker()
@ -72,11 +72,10 @@ class DevicePoolExecutor:
name, name,
device, device,
cancel=Value("B", False), cancel=Value("B", False),
progress=Value("I", 0), progress=self.progress,
finished=self.finished, finished=self.finished,
logs=self.log_queue, logs=self.log_queue,
pending=pending, pending=pending,
started=self.started,
) )
self.context[name] = context self.context[name] = context
self.workers[name] = Process(target=worker_init, args=(context, self.server)) self.workers[name] = Process(target=worker_init, args=(context, self.server))
@ -85,30 +84,32 @@ class DevicePoolExecutor:
self.workers[name].start() self.workers[name].start()
def create_queue_workers(self) -> None: def create_queue_workers(self) -> None:
def started_worker(pending: Queue): def progress_worker(progress: Queue):
logger.info("checking in from started thread") logger.info("checking in from progress worker thread")
while True: while True:
job, device = pending.get() job, device, value = progress.get()
logger.info("job has been started: %s", job) logger.info("progress update for job: %s, %s", job, value)
self.active_jobs[device] = job self.active_jobs[job] = (device, value)
def finished_worker(finished: Queue): def finished_worker(finished: Queue):
logger.info("checking in from finished thread") logger.info("checking in from finished worker thread")
while True: while True:
job, device = finished.get() job, device = finished.get()
logger.info("job has been finished: %s", job) logger.info("job has been finished: %s", job)
context = self.get_job_context(job) context = self.context[device]
_device, progress = self.active_jobs[job]
self.finished_jobs.append( self.finished_jobs.append(
(job, context.progress.value, context.cancel.value) (job, progress, context.cancel.value)
) )
del self.active_jobs[job]
self.started_thread = Thread(target=started_worker, args=(self.started,)) self.progress_thread = Thread(target=progress_worker, args=(self.progress,))
self.started_thread.start() self.progress_thread.start()
self.finished_thread = Thread(target=finished_worker, args=(self.finished,)) self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
self.finished_thread.start() self.finished_thread.start()
def get_job_context(self, key: str) -> WorkerContext: def get_job_context(self, key: str) -> WorkerContext:
device = self.active_jobs[key] device, _progress = self.active_jobs[key]
return self.context[device] return self.context[device]
def cancel(self, key: str) -> bool: def cancel(self, key: str) -> bool:
@ -141,8 +142,8 @@ class DevicePoolExecutor:
return (None, 0) return (None, 0)
# TODO: prune here, maybe? # TODO: prune here, maybe?
context = self.get_job_context(key) _device, progress = self.active_jobs[key]
return (False, context.progress.value) return (False, progress)
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
# respect overrides if possible # respect overrides if possible
@ -164,7 +165,7 @@ class DevicePoolExecutor:
return lowest_devices[0] return lowest_devices[0]
def join(self): def join(self):
self.started_thread.join(self.join_timeout) self.progress_thread.join(self.join_timeout)
self.finished_thread.join(self.join_timeout) self.finished_thread.join(self.join_timeout)
for device, worker in self.workers.items(): for device, worker in self.workers.items():
@ -226,12 +227,12 @@ class DevicePoolExecutor:
( (
name, name,
self.workers[name].is_alive(), self.workers[name].is_alive(),
context.pending.qsize(), self.context[device].pending.qsize(),
context.cancel.value, self.context[device].cancel.value,
False, False,
context.progress.value, progress,
) )
for name, context in self.context.items() for name, device, progress in self.active_jobs
] ]
pending.extend( pending.extend(
[ [

View File

@ -37,9 +37,9 @@ def worker_init(context: WorkerContext, server: ServerContext):
name = args[3][0] name = args[3][0]
try: try:
context.key = name # TODO: hax
context.clear_flags() context.clear_flags()
logger.info("starting job: %s", name) logger.info("starting job: %s", name)
context.put_started(name)
fn(context, *args, **kwargs) fn(context, *args, **kwargs)
logger.info("job succeeded: %s", name) logger.info("job succeeded: %s", name)
except Exception as e: except Exception as e:
@ -48,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext):
format_exception(type(e), e, e.__traceback__), format_exception(type(e), e, e.__traceback__),
) )
finally: finally:
context.put_finished(name) context.set_finished()
logger.info("finished job: %s", name) logger.info("finished job: %s", name)