1
0
Fork 0
onnx-web/api/onnx_web/main.py

104 lines
2.6 KiB
Python
Raw Normal View History

2023-02-26 21:21:58 +00:00
import atexit
import gc
2023-04-15 18:18:24 +00:00
import mimetypes
from functools import partial
2023-02-26 21:21:58 +00:00
from logging import getLogger
from diffusers.utils.logging import disable_progress_bar
from flask import Flask
from flask_cors import CORS
from huggingface_hub.utils.tqdm import disable_progress_bars
from setproctitle import setproctitle
2023-02-26 16:47:31 +00:00
from torch.multiprocessing import set_start_method
from .server.admin import register_admin_routes
from .server.api import register_api_routes
2023-03-05 13:19:48 +00:00
from .server.context import ServerContext
from .server.hacks import apply_patches
from .server.load import (
2023-02-26 20:15:30 +00:00
get_available_platforms,
2023-03-05 05:09:56 +00:00
load_extras,
2023-02-26 20:15:30 +00:00
load_models,
load_params,
load_platforms,
)
from .server.static import register_static_routes
from .server.utils import check_paths
from .utils import is_debug
from .worker import DevicePoolExecutor
2023-02-26 21:21:58 +00:00
logger = getLogger(__name__)


def main():
    """Build the server context, the device worker pool, and the Flask app.

    Steps, in order:
    - set the process title and force the "spawn" multiprocessing start method;
    - register the JS and CSS mimetypes;
    - read the ServerContext from the environment, then apply patches, check
      paths and load extras, models, params and platforms;
    - optionally enable GC stats (debug) and disable progress bars;
    - create a DevicePoolExecutor over every available platform except the
      placeholder "any" device;
    - create the Flask app with CORS and register the static, API and admin
      routes.

    The pool is NOT started here; callers start (and later join) it.

    Returns:
        A ``(server, app, pool)`` tuple.
    """
    setproctitle("onnx-web server")
    set_start_method("spawn", force=True)

    # register mimetypes that may be missing from the system registry
    for mime_type, extension in (
        ("application/javascript", ".js"),
        ("text/css", ".css"),
    ):
        mimetypes.add_type(mime_type, extension)

    # read the environment, then patch, validate paths and load config;
    # the order matters, e.g. platforms must be loaded before they are listed below
    server = ServerContext.from_environ()
    for setup_step in (
        apply_patches,
        check_paths,
        load_extras,
        load_models,
        load_params,
        load_platforms,
    ):
        setup_step(server)

    if is_debug():
        gc.set_debug(gc.DEBUG_STATS)

    if not server.show_progress:
        disable_progress_bar()
        disable_progress_bars()

    # "any" is a fake device and must not get a slot in the worker pool
    worker_devices = [
        platform
        for platform in get_available_platforms()
        if platform.device != "any"
    ]
    pool = DevicePoolExecutor(server, worker_devices)

    app = Flask(__name__)
    CORS(app, origins=server.cors_origin)
    for register_routes in (
        register_static_routes,
        register_api_routes,
        register_admin_routes,
    ):
        register_routes(app, server, pool)

    return server, app, pool
def run():
    """Build the server, start the worker pool, and return the Flask app.

    Unlike the ``__main__`` entry point, this does not call ``app.run``. It
    returns the app instead, presumably so that an external WSGI server can
    serve it (TODO confirm against the deployment scripts). The workers are
    joined by an ``atexit`` hook when the interpreter exits.

    Returns:
        The configured Flask application created by ``main()``.
    """
    server, app, pool = main()

    # named to avoid shadowing the ``quit`` builtin
    def shutdown_workers(p: DevicePoolExecutor):
        logger.info("shutting down workers")
        p.join()

    logger.info("starting image workers")
    pool.start()
    # register the join hook right after starting the pool, so the workers are
    # joined at exit even if anything later in this function fails
    atexit.register(partial(shutdown_workers, pool))

    logger.info(
        "starting %s API server with admin token: %s",
        server.server_version,
        server.admin_token,
    )
    return app
if __name__ == "__main__":
    # script entry point: build everything, start the workers, then serve the
    # Flask development server on all interfaces, port 5000
    server, app, pool = main()
    logger.info("starting image workers")
    pool.start()
    logger.info(
        "starting %s API server with admin token: %s",
        server.server_version,
        server.admin_token,
    )

    try:
        app.run("0.0.0.0", 5000, debug=is_debug())
    finally:
        # join the workers even if the server fails to start (e.g. the port is
        # already in use) or exits with an error, so they are not left running
        logger.info("shutting down workers")
        pool.join()