rejoin worker pool
This commit is contained in:
parent
06c74a7a96
commit
6998e8735c
|
@ -4,6 +4,7 @@ from diffusers.utils.logging import disable_progress_bar
|
|||
from flask import Flask
|
||||
from flask_cors import CORS
|
||||
from huggingface_hub.utils.tqdm import disable_progress_bars
|
||||
from torch.multiprocessing import set_start_method
|
||||
|
||||
from .server.api import register_api_routes
|
||||
from .server.static import register_static_routes
|
||||
|
@ -18,6 +19,8 @@ from .worker import DevicePoolExecutor
|
|||
|
||||
|
||||
def main():
|
||||
set_start_method("spawn", force=True)
|
||||
|
||||
context = ServerContext.from_environ()
|
||||
apply_patches(context)
|
||||
check_paths(context)
|
||||
|
@ -42,12 +45,11 @@ def main():
|
|||
register_static_routes(app, context, pool)
|
||||
register_api_routes(app, context, pool)
|
||||
|
||||
return app #, context, pool
|
||||
return app, pool
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# app, context, pool = main()
|
||||
app = main()
|
||||
app, pool = main()
|
||||
app.run("0.0.0.0", 5000, debug=is_debug())
|
||||
# pool.join()
|
||||
pool.join()
|
||||
|
||||
|
|
|
@ -93,6 +93,12 @@ class DevicePoolExecutor:
|
|||
|
||||
return lowest_devices[0]
|
||||
|
||||
def join(self, timeout=5):
    """
    Wait for every worker that is still alive to exit.

    Each live worker in ``self.workers`` is joined for at most ``timeout``
    seconds. Workers that are already dead are skipped. A worker that is
    still alive after its join is reported with a warning rather than
    being silently left behind.

    Args:
        timeout: Maximum number of seconds to wait for each worker.
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        None.
    """
    for device, worker in self.workers.items():
        if not worker.is_alive():
            continue

        logger.info("stopping worker for device %s", device)
        worker.join(timeout)

        # NOTE(review): assumes workers follow the Process/Thread contract,
        # where join() with a timeout returns None whether or not the worker
        # actually exited. Re-checking liveness is the only way to notice a
        # worker that is hung.
        if worker.is_alive():
            logger.warning(
                "worker for device %s did not stop within %s seconds",
                device,
                timeout,
            )
|
||||
|
||||
def prune(self):
|
||||
finished_count = len(self.finished)
|
||||
if finished_count > self.finished_limit:
|
||||
|
|
Loading…
Reference in New Issue