From f3ab25f6716822d0bd3bde6fe33ecc8e35b20b51 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Mar 2023 11:30:07 -0500 Subject: [PATCH] lint(api): add start method to worker pool --- api/onnx_web/main.py | 12 ++++++++---- api/onnx_web/worker/pool.py | 10 +++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index df56781c..21d95533 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -46,14 +46,16 @@ def main(): disable_progress_bar() disable_progress_bars() - app = Flask(__name__) - CORS(app, origins=context.cors_origin) - - # any is a fake device, should not be in the pool + # create workers + # any is a fake device and should not be in the pool pool = DevicePoolExecutor( context, [p for p in get_available_platforms() if p.device != "any"] ) + # create server + app = Flask(__name__) + CORS(app, origins=context.cors_origin) + # register routes register_static_routes(app, context, pool) register_api_routes(app, context, pool) @@ -63,6 +65,7 @@ def main(): def run(): app, pool = main() + pool.start() def quit(): logger.info("shutting down workers") @@ -74,6 +77,7 @@ def run(): if __name__ == "__main__": app, pool = main() + pool.start() app.run("0.0.0.0", 5000, debug=is_debug()) logger.info("shutting down app") pool.join() diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 489f8e3a..1b9e0fca 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -80,12 +80,12 @@ class DevicePoolExecutor: self.logs = Queue(self.max_pending_per_worker) self.rlock = Lock() - # TODO: these should be part of a start method + def start(self) -> None: self.create_health_worker() self.create_logger_worker() self.create_progress_worker() - for device in devices: + for device in self.devices: self.create_device_worker(device) def create_device_worker(self, device: DeviceParams) -> None: @@ -439,7 +439,11 @@ class DevicePoolExecutor: else: self.total_jobs[progress.device] = 1 - logger.debug("updating job count for device %s: %s", progress.device, self.total_jobs[progress.device]) + logger.debug( + "updating job count for device %s: %s", + progress.device, + self.total_jobs[progress.device], + ) # check if the job has been cancelled if progress.job in self.cancelled_jobs: