rejoin worker pool
This commit is contained in:
parent 06c74a7a96
commit 6998e8735c
|
@ -4,6 +4,7 @@ from diffusers.utils.logging import disable_progress_bar
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from huggingface_hub.utils.tqdm import disable_progress_bars
|
from huggingface_hub.utils.tqdm import disable_progress_bars
|
||||||
|
from torch.multiprocessing import set_start_method
|
||||||
|
|
||||||
from .server.api import register_api_routes
|
from .server.api import register_api_routes
|
||||||
from .server.static import register_static_routes
|
from .server.static import register_static_routes
|
||||||
|
@ -18,6 +19,8 @@ from .worker import DevicePoolExecutor
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
set_start_method("spawn", force=True)
|
||||||
|
|
||||||
context = ServerContext.from_environ()
|
context = ServerContext.from_environ()
|
||||||
apply_patches(context)
|
apply_patches(context)
|
||||||
check_paths(context)
|
check_paths(context)
|
||||||
|
@ -42,12 +45,11 @@ def main():
|
||||||
register_static_routes(app, context, pool)
|
register_static_routes(app, context, pool)
|
||||||
register_api_routes(app, context, pool)
|
register_api_routes(app, context, pool)
|
||||||
|
|
||||||
return app #, context, pool
|
return app, pool
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# app, context, pool = main()
|
app, pool = main()
|
||||||
app = main()
|
|
||||||
app.run("0.0.0.0", 5000, debug=is_debug())
|
app.run("0.0.0.0", 5000, debug=is_debug())
|
||||||
# pool.join()
|
pool.join()
|
||||||
|
|
||||||
|
|
|
@ -93,6 +93,12 @@ class DevicePoolExecutor:
|
||||||
|
|
||||||
return lowest_devices[0]
|
return lowest_devices[0]
|
||||||
|
|
||||||
|
def join(self):
|
||||||
|
for device, worker in self.workers.items():
|
||||||
|
if worker.is_alive():
|
||||||
|
logger.info("stopping worker for device %s", device)
|
||||||
|
worker.join(5)
|
||||||
|
|
||||||
def prune(self):
|
def prune(self):
|
||||||
finished_count = len(self.finished)
|
finished_count = len(self.finished)
|
||||||
if finished_count > self.finished_limit:
|
if finished_count > self.finished_limit:
|
||||||
|
|
Loading…
Reference in New Issue