2023-02-26 21:21:58 +00:00
|
|
|
import atexit
|
2023-02-26 16:15:12 +00:00
|
|
|
import gc
|
2023-02-26 21:21:58 +00:00
|
|
|
from logging import getLogger
|
2023-02-26 16:15:12 +00:00
|
|
|
|
|
|
|
from diffusers.utils.logging import disable_progress_bar
|
|
|
|
from flask import Flask
|
|
|
|
from flask_cors import CORS
|
|
|
|
from huggingface_hub.utils.tqdm import disable_progress_bars
|
2023-02-26 16:47:31 +00:00
|
|
|
from torch.multiprocessing import set_start_method
|
2023-02-26 16:15:12 +00:00
|
|
|
|
|
|
|
from .server.api import register_api_routes
|
2023-02-26 20:15:30 +00:00
|
|
|
from .server.config import (
|
|
|
|
get_available_platforms,
|
|
|
|
load_models,
|
|
|
|
load_params,
|
|
|
|
load_platforms,
|
|
|
|
)
|
2023-02-26 16:15:12 +00:00
|
|
|
from .server.context import ServerContext
|
|
|
|
from .server.hacks import apply_patches
|
2023-02-26 20:15:30 +00:00
|
|
|
from .server.static import register_static_routes
|
|
|
|
from .server.utils import check_paths
|
|
|
|
from .utils import is_debug
|
2023-02-26 16:15:12 +00:00
|
|
|
from .worker import DevicePoolExecutor
|
|
|
|
|
2023-02-26 21:21:58 +00:00
|
|
|
# Module-level logger, named after this module so records can be filtered per module.
logger = getLogger(name=__name__)
|
|
|
|
|
2023-02-26 16:15:12 +00:00
|
|
|
|
|
|
|
def main():
    """Build the Flask application and its device worker pool.

    Startup sequence:
      1. Force the "spawn" multiprocessing start method.
      2. Build a ``ServerContext`` from the environment.
      3. Run the startup steps in order: patches, path checks, then loading
         models, params and platforms.
      4. Enable gc statistics in debug mode and silence the diffusers /
         huggingface_hub progress bars unless ``show_progress`` is set.
      5. Create the Flask app with CORS limited to ``cors_origin``.
      6. Start a ``DevicePoolExecutor`` over every available platform except
         the placeholder ``any`` device.
      7. Register the static and API routes.

    Returns:
        A ``(app, pool)`` tuple: the Flask app and the device pool.
    """
    # force=True overrides a start method that may already have been set;
    # presumably "spawn" is required by the device workers - confirm.
    set_start_method("spawn", force=True)

    server_context = ServerContext.from_environ()

    # Each step receives the same context and runs in this exact order.
    startup_steps = (
        apply_patches,
        check_paths,
        load_models,
        load_params,
        load_platforms,
    )
    for step in startup_steps:
        step(server_context)

    if is_debug():
        gc.set_debug(gc.DEBUG_STATS)

    if not server_context.show_progress:
        disable_progress_bar()  # diffusers
        disable_progress_bars()  # huggingface_hub

    flask_app = Flask(__name__)
    CORS(flask_app, origins=server_context.cors_origin)

    # "any" is a fake device, so it must not get a worker in the pool.
    real_platforms = [
        candidate
        for candidate in get_available_platforms()
        if candidate.device != "any"
    ]
    worker_pool = DevicePoolExecutor(server_context, real_platforms)

    # Static routes are registered before API routes, as in the original.
    for register_routes in (register_static_routes, register_api_routes):
        register_routes(flask_app, server_context, worker_pool)

    return flask_app, worker_pool
|
2023-02-26 16:15:12 +00:00
|
|
|
|
|
|
|
|
2023-02-26 21:21:58 +00:00
|
|
|
def run():
    """Build the app and arrange for the worker pool to be joined at exit.

    Unlike ``main``, this returns only the Flask app. The device pool is not
    returned; an ``atexit`` handler joins it when the interpreter shuts down.

    Returns:
        The configured ``Flask`` application.
    """
    app, pool = main()

    # Named ``shutdown_workers`` rather than ``quit`` so the builtin
    # ``quit`` is not shadowed inside this function.
    def shutdown_workers():
        logger.info("shutting down workers")
        pool.join()

    atexit.register(shutdown_workers)

    return app
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    app, pool = main()
    try:
        # Blocks until the development server stops.
        app.run("0.0.0.0", 5000, debug=is_debug())
    finally:
        # Always join the device pool, even if the server exits with an
        # error (e.g. the port is already in use), so the workers are not
        # left running.
        logger.info("shutting down app")
        pool.join()
|