1
0
Fork 0

rejoin worker pool

This commit is contained in:
Sean Sube 2023-02-26 10:47:31 -06:00
parent 06c74a7a96
commit 6998e8735c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 12 additions and 4 deletions

View File

@ -4,6 +4,7 @@ from diffusers.utils.logging import disable_progress_bar
from flask import Flask
from flask_cors import CORS
from huggingface_hub.utils.tqdm import disable_progress_bars
from torch.multiprocessing import set_start_method
from .server.api import register_api_routes
from .server.static import register_static_routes
@ -18,6 +19,8 @@ from .worker import DevicePoolExecutor
def main():
set_start_method("spawn", force=True)
context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
@ -42,12 +45,11 @@ def main():
register_static_routes(app, context, pool)
register_api_routes(app, context, pool)
return app #, context, pool
return app, pool
if __name__ == "__main__":
# app, context, pool = main()
app = main()
app, pool = main()
app.run("0.0.0.0", 5000, debug=is_debug())
# pool.join()
pool.join()

View File

@ -93,6 +93,12 @@ class DevicePoolExecutor:
return lowest_devices[0]
def join(self):
for device, worker in self.workers.items():
if worker.is_alive():
logger.info("stopping worker for device %s", device)
worker.join(5)
def prune(self):
finished_count = len(self.finished)
if finished_count > self.finished_limit: