use progress queue
This commit is contained in:
parent
401ee20526
commit
a37d1a4550
|
@ -22,11 +22,10 @@ class WorkerContext:
|
||||||
key: str,
|
key: str,
|
||||||
device: DeviceParams,
|
device: DeviceParams,
|
||||||
cancel: "Value[bool]" = None,
|
cancel: "Value[bool]" = None,
|
||||||
progress: "Value[int]" = None,
|
|
||||||
finished: "Queue[str]" = None,
|
|
||||||
logs: "Queue[str]" = None,
|
logs: "Queue[str]" = None,
|
||||||
pending: "Queue[Any]" = None,
|
pending: "Queue[Any]" = None,
|
||||||
started: "Queue[Tuple[str, str]]" = None,
|
progress: "Queue[Tuple[str, int]]" = None,
|
||||||
|
finished: "Queue[str]" = None,
|
||||||
):
|
):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -35,7 +34,6 @@ class WorkerContext:
|
||||||
self.finished = finished
|
self.finished = finished
|
||||||
self.logs = logs
|
self.logs = logs
|
||||||
self.pending = pending
|
self.pending = pending
|
||||||
self.started = started
|
|
||||||
|
|
||||||
def is_cancelled(self) -> bool:
|
def is_cancelled(self) -> bool:
|
||||||
return self.cancel.value
|
return self.cancel.value
|
||||||
|
@ -65,14 +63,10 @@ class WorkerContext:
|
||||||
self.cancel.value = cancel
|
self.cancel.value = cancel
|
||||||
|
|
||||||
def set_progress(self, progress: int) -> None:
|
def set_progress(self, progress: int) -> None:
|
||||||
with self.progress.get_lock():
|
self.progress.put((self.key, self.device.device, progress))
|
||||||
self.progress.value = progress
|
|
||||||
|
|
||||||
def put_finished(self, job: str) -> None:
|
def set_finished(self) -> None:
|
||||||
self.finished.put((job, self.device.device))
|
self.finished.put((self.key, self.device.device))
|
||||||
|
|
||||||
def put_started(self, job: str) -> None:
|
|
||||||
self.started.put((job, self.device.device))
|
|
||||||
|
|
||||||
def clear_flags(self) -> None:
|
def clear_flags(self) -> None:
|
||||||
self.set_cancel(False)
|
self.set_cancel(False)
|
||||||
|
|
|
@ -18,7 +18,7 @@ class DevicePoolExecutor:
|
||||||
devices: List[DeviceParams] = None
|
devices: List[DeviceParams] = None
|
||||||
pending: Dict[str, "Queue[WorkerContext]"] = None
|
pending: Dict[str, "Queue[WorkerContext]"] = None
|
||||||
workers: Dict[str, Process] = None
|
workers: Dict[str, Process] = None
|
||||||
active_jobs: Dict[str, str] = None
|
active_jobs: Dict[str, Tuple[str, int]] = None
|
||||||
finished_jobs: List[Tuple[str, int, bool]] = None
|
finished_jobs: List[Tuple[str, int, bool]] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -40,7 +40,7 @@ class DevicePoolExecutor:
|
||||||
self.finished_jobs = []
|
self.finished_jobs = []
|
||||||
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
|
self.total_jobs = 0 # TODO: turn this into a Dict per-worker
|
||||||
|
|
||||||
self.started = Queue()
|
self.progress = Queue()
|
||||||
self.finished = Queue()
|
self.finished = Queue()
|
||||||
|
|
||||||
self.create_logger_worker()
|
self.create_logger_worker()
|
||||||
|
@ -72,11 +72,10 @@ class DevicePoolExecutor:
|
||||||
name,
|
name,
|
||||||
device,
|
device,
|
||||||
cancel=Value("B", False),
|
cancel=Value("B", False),
|
||||||
progress=Value("I", 0),
|
progress=self.progress,
|
||||||
finished=self.finished,
|
finished=self.finished,
|
||||||
logs=self.log_queue,
|
logs=self.log_queue,
|
||||||
pending=pending,
|
pending=pending,
|
||||||
started=self.started,
|
|
||||||
)
|
)
|
||||||
self.context[name] = context
|
self.context[name] = context
|
||||||
self.workers[name] = Process(target=worker_init, args=(context, self.server))
|
self.workers[name] = Process(target=worker_init, args=(context, self.server))
|
||||||
|
@ -85,30 +84,32 @@ class DevicePoolExecutor:
|
||||||
self.workers[name].start()
|
self.workers[name].start()
|
||||||
|
|
||||||
def create_queue_workers(self) -> None:
|
def create_queue_workers(self) -> None:
|
||||||
def started_worker(pending: Queue):
|
def progress_worker(progress: Queue):
|
||||||
logger.info("checking in from started thread")
|
logger.info("checking in from progress worker thread")
|
||||||
while True:
|
while True:
|
||||||
job, device = pending.get()
|
job, device, value = progress.get()
|
||||||
logger.info("job has been started: %s", job)
|
logger.info("progress update for job: %s, %s", job, value)
|
||||||
self.active_jobs[device] = job
|
self.active_jobs[job] = (device, value)
|
||||||
|
|
||||||
def finished_worker(finished: Queue):
|
def finished_worker(finished: Queue):
|
||||||
logger.info("checking in from finished thread")
|
logger.info("checking in from finished worker thread")
|
||||||
while True:
|
while True:
|
||||||
job, device = finished.get()
|
job, device = finished.get()
|
||||||
logger.info("job has been finished: %s", job)
|
logger.info("job has been finished: %s", job)
|
||||||
context = self.get_job_context(job)
|
context = self.context[device]
|
||||||
|
_device, progress = self.active_jobs[job]
|
||||||
self.finished_jobs.append(
|
self.finished_jobs.append(
|
||||||
(job, context.progress.value, context.cancel.value)
|
(job, progress, context.cancel.value)
|
||||||
)
|
)
|
||||||
|
del self.active_jobs[job]
|
||||||
|
|
||||||
self.started_thread = Thread(target=started_worker, args=(self.started,))
|
self.progress_thread = Thread(target=progress_worker, args=(self.progress,))
|
||||||
self.started_thread.start()
|
self.progress_thread.start()
|
||||||
self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
|
self.finished_thread = Thread(target=finished_worker, args=(self.finished,))
|
||||||
self.finished_thread.start()
|
self.finished_thread.start()
|
||||||
|
|
||||||
def get_job_context(self, key: str) -> WorkerContext:
|
def get_job_context(self, key: str) -> WorkerContext:
|
||||||
device = self.active_jobs[key]
|
device, _progress = self.active_jobs[key]
|
||||||
return self.context[device]
|
return self.context[device]
|
||||||
|
|
||||||
def cancel(self, key: str) -> bool:
|
def cancel(self, key: str) -> bool:
|
||||||
|
@ -141,8 +142,8 @@ class DevicePoolExecutor:
|
||||||
return (None, 0)
|
return (None, 0)
|
||||||
|
|
||||||
# TODO: prune here, maybe?
|
# TODO: prune here, maybe?
|
||||||
context = self.get_job_context(key)
|
_device, progress = self.active_jobs[key]
|
||||||
return (False, context.progress.value)
|
return (False, progress)
|
||||||
|
|
||||||
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int:
|
||||||
# respect overrides if possible
|
# respect overrides if possible
|
||||||
|
@ -164,7 +165,7 @@ class DevicePoolExecutor:
|
||||||
return lowest_devices[0]
|
return lowest_devices[0]
|
||||||
|
|
||||||
def join(self):
|
def join(self):
|
||||||
self.started_thread.join(self.join_timeout)
|
self.progress_thread.join(self.join_timeout)
|
||||||
self.finished_thread.join(self.join_timeout)
|
self.finished_thread.join(self.join_timeout)
|
||||||
|
|
||||||
for device, worker in self.workers.items():
|
for device, worker in self.workers.items():
|
||||||
|
@ -226,12 +227,12 @@ class DevicePoolExecutor:
|
||||||
(
|
(
|
||||||
name,
|
name,
|
||||||
self.workers[name].is_alive(),
|
self.workers[name].is_alive(),
|
||||||
context.pending.qsize(),
|
self.context[device].pending.qsize(),
|
||||||
context.cancel.value,
|
self.context[device].cancel.value,
|
||||||
False,
|
False,
|
||||||
context.progress.value,
|
progress,
|
||||||
)
|
)
|
||||||
for name, context in self.context.items()
|
for name, device, progress in self.active_jobs
|
||||||
]
|
]
|
||||||
pending.extend(
|
pending.extend(
|
||||||
[
|
[
|
||||||
|
|
|
@ -37,9 +37,9 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
||||||
name = args[3][0]
|
name = args[3][0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
context.key = name # TODO: hax
|
||||||
context.clear_flags()
|
context.clear_flags()
|
||||||
logger.info("starting job: %s", name)
|
logger.info("starting job: %s", name)
|
||||||
context.put_started(name)
|
|
||||||
fn(context, *args, **kwargs)
|
fn(context, *args, **kwargs)
|
||||||
logger.info("job succeeded: %s", name)
|
logger.info("job succeeded: %s", name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -48,5 +48,5 @@ def worker_init(context: WorkerContext, server: ServerContext):
|
||||||
format_exception(type(e), e, e.__traceback__),
|
format_exception(type(e), e, e.__traceback__),
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
context.put_finished(name)
|
context.set_finished()
|
||||||
logger.info("finished job: %s", name)
|
logger.info("finished job: %s", name)
|
||||||
|
|
Loading…
Reference in New Issue