2023-04-20 12:36:31 +00:00
|
|
|
from logging import getLogger
|
|
|
|
|
|
|
|
from flask import Flask, jsonify, make_response, request
|
2023-05-06 21:49:02 +00:00
|
|
|
from jsonschema import ValidationError, validate
|
2023-04-20 12:36:31 +00:00
|
|
|
|
2023-05-06 21:49:02 +00:00
|
|
|
from ..utils import load_config, load_config_str
|
2023-04-20 12:36:31 +00:00
|
|
|
from ..worker.pool import DevicePoolExecutor
|
|
|
|
from .context import ServerContext
|
2023-05-09 02:59:27 +00:00
|
|
|
from .load import load_extras, load_models
|
2023-04-20 12:36:31 +00:00
|
|
|
from .utils import wrap_route
|
|
|
|
|
|
|
|
# Module-level logger used by the admin route handlers below.
logger = getLogger(__name__)
|
|
|
|
|
2023-05-10 13:17:22 +00:00
|
|
|
# True while update_extra_models is converting and reloading models; that
# handler answers 409 while it is set. It is a plain bool that is checked and
# then assigned without synchronization.
conversion_lock = False
|
|
|
|
|
2023-04-20 12:36:31 +00:00
|
|
|
|
|
|
|
def check_admin(server: ServerContext) -> bool:
    """Return True when the request's ``token`` query argument matches the admin token.

    A request is always rejected when the server has no admin token configured.
    Otherwise an unset token (None) would be "matched" by simply omitting the
    ``token`` query argument, since ``None == None``.

    Args:
        server: server context whose ``admin_token`` is the expected secret.

    Returns:
        True only for a non-empty configured token that equals the supplied one.
    """
    from secrets import compare_digest

    expected = server.admin_token
    provided = request.args.get("token", None)
    if not expected or provided is None:
        return False

    # Constant-time comparison so response timing does not leak the token.
    # Both sides are encoded first because compare_digest raises TypeError
    # for str arguments that contain non-ASCII characters.
    return compare_digest(provided.encode("utf-8"), str(expected).encode("utf-8"))
|
|
|
|
|
|
|
|
|
|
|
|
def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
    """Recycle every worker in ``pool`` and report the pool status afterwards.

    Admin-only: when ``check_admin`` rejects the request, responds 401 with an
    empty JSON object and leaves the pool untouched.
    """
    authorized = check_admin(server)
    if authorized:
        logger.info("restarting worker pool")
        pool.recycle(recycle_all=True)
        logger.info("restarted worker pool")

        return jsonify(pool.status())

    return make_response(jsonify({})), 401
|
|
|
|
|
2023-04-20 12:36:31 +00:00
|
|
|
|
|
|
|
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
    """Return ``pool.status()`` as JSON for an admin caller.

    Responds 401 with an empty JSON object when ``check_admin`` rejects the request.
    """
    if check_admin(server):
        return jsonify(pool.status())

    return make_response(jsonify({})), 401
|
|
|
|
|
|
|
|
|
2023-05-06 21:49:02 +00:00
|
|
|
def get_extra_models(server: ServerContext):
    """Return the contents of the primary extras file, ``server.extra_models[0]``.

    Admin-only. The file text is returned verbatim with an ``application/json``
    content type.

    Returns:
        The file response on success. Otherwise an empty JSON object with status
        401 (``check_admin`` rejected the request) or 404 (no extras file is
        configured).
    """
    if not check_admin(server):
        return make_response(jsonify({})), 401

    # Indexing [0] on an empty (or None) list would raise and surface as a 500.
    if not server.extra_models:
        return make_response(jsonify({})), 404

    # Explicit UTF-8 rather than the platform default, so non-ASCII model names
    # round-trip the same way on every OS.
    with open(server.extra_models[0], encoding="utf-8") as f:
        resp = make_response(f.read())

    resp.content_type = "application/json"
    return resp
|
|
|
|
|
|
|
|
|
|
|
|
def update_extra_models(server: ServerContext):
    """Replace the primary extras file, convert its models, and reload the server.

    Admin-only. The request body is processed in three steps:
      1. Decode it using the ``charset`` parameter of its Content-Type (UTF-8 when absent).
      2. Parse it with ``load_config_str``.
      3. Validate it against ``./schemas/extras.yaml``.

    Only a body that passes all three steps is written to
    ``server.extra_models[0]``. The conversion entry point then runs over every
    file in ``server.extra_models``, and models and extras are reloaded. All of
    this happens synchronously within the request.

    Returns:
        The parsed extras data as JSON on success. Otherwise one of:
          401: ``check_admin`` rejected the request (empty JSON object).
          404: no extras file is configured (empty JSON object).
          409: a conversion is already running (empty JSON object).
          400: the body could not be decoded, parsed, or validated
               (JSON object with an ``error`` key).
    """
    global conversion_lock

    if not check_admin(server):
        return make_response(jsonify({})), 401

    if not server.extra_models:
        return make_response(jsonify({})), 404

    if conversion_lock:
        return make_response(jsonify({})), 409

    # Claim the lock before touching the extras file, so an overlapping request
    # gets a 409 instead of rewriting the file mid-conversion. The finally below
    # releases it on every exit path, including a failed conversion, which would
    # otherwise leave the endpoint stuck returning 409.
    # NOTE(review): this is a plain bool, so two simultaneous requests can still
    # both pass the check above.
    conversion_lock = True
    try:
        extra_schema = load_config("./schemas/extras.yaml")

        # The text charset is a Content-Type parameter. Content-Encoding names a
        # compression scheme (e.g. gzip) and is not a usable codec name.
        charset = request.mimetype_params.get("charset") or "utf-8"

        try:
            data_str = request.data.decode(encoding=charset)
            data = load_config_str(data_str)
            validate(data, extra_schema)
        except ValidationError as err:
            logger.exception("invalid data in extras file")
            return make_response(jsonify({"error": err.message})), 400
        except Exception:
            # The failure types of load_config_str are not visible here, so any
            # decode/parse failure (including an unknown charset) is reported
            # as a bad request rather than written to disk.
            logger.exception("error validating extras file")
            return make_response(jsonify({"error": "could not parse extras file"})), 400

        # TODO: make a backup
        with open(server.extra_models[0], mode="w", encoding="utf-8") as f:
            f.write(data_str)

        logger.warning("downloading and converting models to ONNX")

        # Imported here rather than at module level, presumably so the
        # conversion code is only loaded when a conversion is requested.
        from onnx_web.convert.__main__ import main as convert

        convert(
            args=[
                "--correction",
                "--diffusion",
                "--upscaling",
                "--extras",
                *server.extra_models,
            ]
        )

        logger.info("finished converting models, reloading server")
        load_models(server)
        load_extras(server)
    finally:
        conversion_lock = False

    return jsonify(data)
|
2023-05-06 21:49:02 +00:00
|
|
|
|
|
|
|
|
2023-04-20 12:36:31 +00:00
|
|
|
def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
    """Attach the admin endpoints to ``app`` and return the registered handlers.

    Registration order (each handler is wrapped with ``wrap_route``):
      /api/extras  (default methods) -> get_extra_models
      /api/extras  PUT               -> update_extra_models
      /api/restart POST              -> restart_workers (with pool)
      /api/status  (default methods) -> worker_status (with pool)
    """
    # (rule, options for app.route, handler, extra keyword args for wrap_route)
    route_table = [
        ("/api/extras", {}, get_extra_models, {}),
        ("/api/extras", {"methods": ["PUT"]}, update_extra_models, {}),
        ("/api/restart", {"methods": ["POST"]}, restart_workers, {"pool": pool}),
        ("/api/status", {}, worker_status, {"pool": pool}),
    ]

    registered = []
    for rule, route_options, handler, wrap_options in route_table:
        decorator = app.route(rule, **route_options)
        registered.append(decorator(wrap_route(handler, server, **wrap_options)))

    return registered
|