From 6998e8735ce059208ed04a2320867a7089edbc95 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 26 Feb 2023 10:47:31 -0600 Subject: [PATCH] rejoin worker pool --- api/onnx_web/main.py | 10 ++++++---- api/onnx_web/worker/pool.py | 6 ++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/main.py b/api/onnx_web/main.py index 6381a241..bb19466c 100644 --- a/api/onnx_web/main.py +++ b/api/onnx_web/main.py @@ -4,6 +4,7 @@ from diffusers.utils.logging import disable_progress_bar from flask import Flask from flask_cors import CORS from huggingface_hub.utils.tqdm import disable_progress_bars +from torch.multiprocessing import set_start_method from .server.api import register_api_routes from .server.static import register_static_routes @@ -18,6 +19,8 @@ from .worker import DevicePoolExecutor def main(): + set_start_method("spawn", force=True) + context = ServerContext.from_environ() apply_patches(context) check_paths(context) @@ -42,12 +45,11 @@ def main(): register_static_routes(app, context, pool) register_api_routes(app, context, pool) - return app #, context, pool + return app, pool if __name__ == "__main__": - # app, context, pool = main() - app = main() + app, pool = main() app.run("0.0.0.0", 5000, debug=is_debug()) - # pool.join() + pool.join() diff --git a/api/onnx_web/worker/pool.py b/api/onnx_web/worker/pool.py index 7d83427c..b7eaf3a3 100644 --- a/api/onnx_web/worker/pool.py +++ b/api/onnx_web/worker/pool.py @@ -93,6 +93,12 @@ class DevicePoolExecutor: return lowest_devices[0] + def join(self): + for device, worker in self.workers.items(): + if worker.is_alive(): + logger.info("stopping worker for device %s", device) + worker.join(5) + def prune(self): finished_count = len(self.finished) if finished_count > self.finished_limit: