From 66a20e60fef15234301902a079f7e1959c49f3df Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 27 Feb 2023 17:14:53 -0600 Subject: [PATCH] run logger in a thread, clean up status --- api/onnx_web/server/api.py | 2 +- api/onnx_web/worker/context.py | 14 +++ api/onnx_web/worker/pool.py | 185 ++++++++++++++++++--------------- api/onnx_web/worker/worker.py | 16 +-- 4 files changed, 121 insertions(+), 96 deletions(-) diff --git a/api/onnx_web/server/api.py b/api/onnx_web/server/api.py index 7ff61192..59921246 100644 --- a/api/onnx_web/server/api.py +++ b/api/onnx_web/server/api.py @@ -402,7 +402,7 @@ def blend(context: ServerContext, pool: DevicePoolExecutor): def txt2txt(context: ServerContext, pool: DevicePoolExecutor): device, params, size = pipeline_from_request(context) - output = make_output_name(context, "upscale", params, size) + output = make_output_name(context, "txt2txt", params, size) logger.info("upscale job queued for: %s", output) pool.submit( diff --git a/api/onnx_web/worker/context.py b/api/onnx_web/worker/context.py index 2daf23d0..1ef564a7 100644 --- a/api/onnx_web/worker/context.py +++ b/api/onnx_web/worker/context.py @@ -71,3 +71,17 @@ class WorkerContext: def clear_flags(self) -> None: self.set_cancel(False) self.set_progress(0) + + +class JobStatus: + def __init__( + self, + name: str, + progress: int = 0, + cancelled: bool = False, + finished: bool = False, + ) -> None: + self.name = name + self.progress = progress + self.cancelled = cancelled + self.finished = finished diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index ce45ed8a..dced875d 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -8,7 +8,7 @@ from torch.multiprocessing import Process, Queue, Value from ..params import DeviceParams from ..server import ServerContext from .context import WorkerContext -from .worker import logger_init, worker_init +from .worker import worker_main logger = getLogger(__name__) @@ -18,15 +18,15 @@ class DevicePoolExecutor: devices: List[DeviceParams] = None pending: Dict[str, "Queue[WorkerContext]"] = None workers: Dict[str, Process] = None - active_jobs: Dict[str, Tuple[str, int]] = None - finished_jobs: List[Tuple[str, int, bool]] = None + active_jobs: Dict[str, Tuple[str, int]] = None # should be Dict[Device, JobStatus] + finished_jobs: List[Tuple[str, int, bool]] = None # should be List[JobStatus] def __init__( self, server: ServerContext, devices: List[DeviceParams], max_jobs_per_worker: int = 10, - join_timeout: float = 5.0, + join_timeout: float = 1.0, ): self.server = server self.devices = devices @@ -35,28 +35,27 @@ class DevicePoolExecutor: self.context = {} self.pending = {} + self.threads = {} self.workers = {} + self.active_jobs = {} + self.cancelled_jobs = [] self.finished_jobs = [] self.total_jobs = 0 # TODO: turn this into a Dict per-worker + self.logs = Queue() self.progress = Queue() self.finished = Queue() self.create_logger_worker() - self.create_queue_workers() + self.create_progress_worker() + self.create_finished_worker() + for device in devices: self.create_device_worker(device) logger.debug("testing log worker") - self.log_queue.put("testing") - - def create_logger_worker(self) -> None: - self.log_queue = Queue() - self.logger = Process(target=logger_init, args=(self.log_queue,)) - - logger.debug("starting log worker") - self.logger.start() + self.logs.put("testing") def create_device_worker(self, device: DeviceParams) -> None: name = device.device @@ -74,23 +73,54 @@ class DevicePoolExecutor: cancel=Value("B", False), progress=self.progress, finished=self.finished, - logs=self.log_queue, + logs=self.logs, pending=pending, ) self.context[name] = context - self.workers[name] = Process(target=worker_init, args=(context, self.server)) + self.workers[name] = Process(target=worker_main, args=(context, self.server)) logger.debug("starting worker for device %s", device) self.workers[name].start() - def create_queue_workers(self) -> None: + def create_logger_worker(self) -> None: + def logger_worker(logs: Queue): + logger.info("checking in from logger worker thread") + + while True: + job = logs.get() + with open("worker.log", "w") as f: + logger.info("got log: %s", job) + f.write(str(job) + "\n\n") + + logger_thread = Thread(target=logger_worker, args=(self.logs,)) + self.threads["logger"] = logger_thread + + logger.debug("starting logger worker") + logger_thread.start() + + def create_progress_worker(self) -> None: def progress_worker(progress: Queue): logger.info("checking in from progress worker thread") while True: - job, device, value = progress.get() - logger.info("progress update for job: %s, %s", job, value) - self.active_jobs[job] = (device, value) + try: + job, device, value = progress.get() + logger.info("progress update for job: %s to %s", job, value) + self.active_jobs[job] = (device, value) + if job in self.cancelled_jobs: + logger.debug( + "setting flag for cancelled job: %s on %s", job, device + ) + self.context[device].set_cancel() + except Exception as err: + logger.error("error during progress update", err) + progress_thread = Thread(target=progress_worker, args=(self.progress,)) + self.threads["progress"] = progress_thread + + logger.debug("starting progress worker") + progress_thread.start() + + def create_finished_worker(self) -> None: def finished_worker(finished: Queue): logger.info("checking in from finished worker thread") while True: @@ -98,53 +128,19 @@ class DevicePoolExecutor: logger.info("job has been finished: %s", job) context = self.context[device] _device, progress = self.active_jobs[job] - self.finished_jobs.append( - (job, progress, context.cancel.value) - ) + self.finished_jobs.append((job, progress, context.cancel.value)) del self.active_jobs[job] - self.progress_thread = Thread(target=progress_worker, args=(self.progress,)) - self.progress_thread.start() - self.finished_thread = Thread(target=finished_worker, args=(self.finished,)) - self.finished_thread.start() + finished_thread = Thread(target=finished_worker, args=(self.finished,)) + self.thread["finished"] = finished_thread + + logger.debug("started finished worker") + finished_thread.start() def get_job_context(self, key: str) -> WorkerContext: device, _progress = self.active_jobs[key] return self.context[device] - def cancel(self, key: str) -> bool: - """ - Cancel a job. If the job has not been started, this will cancel - the future and never execute it. If the job has been started, it - should be cancelled on the next progress callback. - """ - if key not in self.active_jobs: - logger.warn("attempting to cancel unknown job: %s", key) - return False - - device, _progress = self.active_jobs[key] - context = self.context[device] - logger.info("cancelling job %s on device %s", key, device) - - if context.cancel.get_lock(): - context.cancel.value = True - - # self.finished.append((key, context.progress.value, context.cancel.value)) maybe? - return True - - def done(self, key: str) -> Tuple[Optional[bool], int]: - for k, p, c in self.finished_jobs: - if k == key: - return (True, p) - - if key not in self.active_jobs: - logger.warn("checking status for unknown job: %s", key) - return (None, 0) - - # TODO: prune here, maybe? - _device, progress = self.active_jobs[key] - return (False, progress) - def get_next_device(self, needs_device: Optional[DeviceParams] = None) -> int: # respect overrides if possible if needs_device is not None: @@ -164,6 +160,45 @@ class DevicePoolExecutor: return lowest_devices[0] + def cancel(self, key: str) -> bool: + """ + Cancel a job. If the job has not been started, this will cancel + the future and never execute it. If the job has been started, it + should be cancelled on the next progress callback. + """ + + self.cancelled_jobs.append(key) + + if key not in self.active_jobs: + logger.debug("cancelled job has not been started yet: %s", key) + return True + + device, _progress = self.active_jobs[key] + logger.info("cancelling job %s, active on device %s", key, device) + + context = self.context[device] + context.set_cancel() + + return True + + def done(self, key: str) -> Tuple[Optional[bool], int]: + """ + Check if a job has been finished and report the last progress update. + + If the job is still active or pending, the first item will be False. + If the job is not finished or active, the first item will be None. + """ + for k, p, c in self.finished_jobs: + if k == key: + return (True, p) + + if key not in self.active_jobs: + logger.warn("checking status for unknown job: %s", key) + return (None, 0) + + _device, progress = self.active_jobs[key] + return (False, progress) + def join(self): self.progress_thread.join(self.join_timeout) self.finished_thread.join(self.join_timeout) @@ -216,35 +251,23 @@ class DevicePoolExecutor: self.devices[device_idx], ) - device = self.devices[device_idx] - queue = self.pending[device.device] - queue.put((fn, args, kwargs)) + device = self.devices[device_idx].device + self.pending[device].put((fn, args, kwargs)) - self.active_jobs[key] = (device.device, 0) - - def status(self) -> List[Tuple[str, int, bool, int]]: - pending = [ - ( - name, - self.workers[name].is_alive(), - self.context[device].pending.qsize(), - self.context[device].cancel.value, - False, - progress, - ) - for name, device, progress in self.active_jobs.items() + def status(self) -> List[Tuple[str, int, bool, bool]]: + history = [ + (name, progress, False, name in self.cancelled_jobs) + for name, _device, progress in self.active_jobs.items() ] - pending.extend( + history.extend( [ ( name, - False, - 0, - cancel, - True, progress, + True, + cancel, ) for name, progress, cancel in self.finished_jobs ] ) - return pending + return history diff --git a/api/onnx_web/worker/worker.py b/api/onnx_web/worker/worker.py index cbd3afa7..9518f948 100644 --- a/api/onnx_web/worker/worker.py +++ b/api/onnx_web/worker/worker.py @@ -11,19 +11,7 @@ from .context import WorkerContext logger = getLogger(__name__) -def logger_init(logs: Queue): - setproctitle("onnx-web logger") - - logger.info("checking in from logger") - - while True: - job = logs.get() - with open("worker.log", "w") as f: - logger.info("got log: %s", job) - f.write(str(job) + "\n\n") - - -def worker_init(context: WorkerContext, server: ServerContext): +def worker_main(context: WorkerContext, server: ServerContext): apply_patches(server) setproctitle("onnx-web worker: %s" % (context.device.device)) @@ -37,7 +25,7 @@ def worker_init(context: WorkerContext, server: ServerContext): name = args[3][0] try: - context.key = name # TODO: hax + context.key = name # TODO: hax context.clear_flags() logger.info("starting job: %s", name) fn(context, *args, **kwargs)