feat: add admin endpoint to restart image workers (#207)
This commit is contained in:
parent
e12f3c2801
commit
df0e7dc57e
|
@ -23,7 +23,7 @@ if __name__ == '__main__':
|
|||
|
||||
# create the server and load the config
|
||||
from onnx_web.main import main
|
||||
app, pool = main()
|
||||
server, app, pool = main()
|
||||
|
||||
# launch the image workers
|
||||
print("starting image workers")
|
||||
|
@ -39,7 +39,7 @@ if __name__ == '__main__':
|
|||
# launch the user's web browser
|
||||
print("opening web browser")
|
||||
url = "http://127.0.0.1:5000"
|
||||
webbrowser.open_new_tab(f"{url}?api={url}")
|
||||
webbrowser.open_new_tab(f"{url}?api={url}&token={server.admin_token}")
|
||||
|
||||
# wait for enter and exit
|
||||
input("press enter to quit")
|
||||
|
|
|
@ -12,10 +12,9 @@ from onnx import load_model, save_model
|
|||
from transformers import CLIPTokenizer
|
||||
from yaml import safe_load
|
||||
|
||||
from onnx_web.convert.diffusion.control import convert_diffusion_control
|
||||
|
||||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||
from .correction.gfpgan import convert_correction_gfpgan
|
||||
from .diffusion.control import convert_diffusion_control
|
||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||
from .diffusion.lora import blend_loras
|
||||
from .diffusion.original import convert_diffusion_original
|
||||
|
|
|
@ -11,6 +11,7 @@ from huggingface_hub.utils.tqdm import disable_progress_bars
|
|||
from setproctitle import setproctitle
|
||||
from torch.multiprocessing import set_start_method
|
||||
|
||||
from .server.admin import register_admin_routes
|
||||
from .server.api import register_api_routes
|
||||
from .server.context import ServerContext
|
||||
from .server.hacks import apply_patches
|
||||
|
@ -37,55 +38,58 @@ def main():
|
|||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
|
||||
context = ServerContext.from_environ()
|
||||
apply_patches(context)
|
||||
check_paths(context)
|
||||
load_extras(context)
|
||||
load_models(context)
|
||||
load_params(context)
|
||||
load_platforms(context)
|
||||
# launch server, read env and list paths
|
||||
server = ServerContext.from_environ()
|
||||
apply_patches(server)
|
||||
check_paths(server)
|
||||
load_extras(server)
|
||||
load_models(server)
|
||||
load_params(server)
|
||||
load_platforms(server)
|
||||
|
||||
if is_debug():
|
||||
gc.set_debug(gc.DEBUG_STATS)
|
||||
|
||||
if not context.show_progress:
|
||||
if not server.show_progress:
|
||||
disable_progress_bar()
|
||||
disable_progress_bars()
|
||||
|
||||
# create workers
|
||||
# any is a fake device and should not be in the pool
|
||||
pool = DevicePoolExecutor(
|
||||
context, [p for p in get_available_platforms() if p.device != "any"]
|
||||
server, [p for p in get_available_platforms() if p.device != "any"]
|
||||
)
|
||||
|
||||
# create server
|
||||
app = Flask(__name__)
|
||||
CORS(app, origins=context.cors_origin)
|
||||
CORS(app, origins=server.cors_origin)
|
||||
|
||||
# register routes
|
||||
register_static_routes(app, context, pool)
|
||||
register_api_routes(app, context, pool)
|
||||
register_static_routes(app, server, pool)
|
||||
register_api_routes(app, server, pool)
|
||||
register_admin_routes(app, server, pool)
|
||||
|
||||
return app, pool
|
||||
return server, app, pool
|
||||
|
||||
|
||||
def run():
|
||||
app, pool = main()
|
||||
_server, app, pool = main()
|
||||
pool.start()
|
||||
|
||||
def quit(p: DevicePoolExecutor):
|
||||
logger.info("shutting down workers")
|
||||
p.join()
|
||||
|
||||
# TODO: print admin token
|
||||
atexit.register(partial(quit, pool))
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app, pool = main()
|
||||
server, app, pool = main()
|
||||
logger.info("starting image workers")
|
||||
pool.start()
|
||||
logger.info("starting API server")
|
||||
logger.info("starting API server with admin token: %s", server.admin_token)
|
||||
app.run("0.0.0.0", 5000, debug=is_debug())
|
||||
logger.info("shutting down workers")
|
||||
pool.join()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
from hmac import compare_digest
from logging import getLogger

from flask import Flask, jsonify, make_response, request

from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .utils import wrap_route
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def check_admin(server: ServerContext):
    """Check whether the request carries the server's admin token.

    Reads the ``token`` query parameter and compares it against
    ``server.admin_token``. The comparison uses ``hmac.compare_digest``
    rather than ``==`` so the check runs in constant time and does not
    leak token contents through response-timing differences.
    """
    token = request.args.get("token", None)
    if token is None:
        # compare_digest rejects None; a missing token is simply unauthorized
        return False

    return compare_digest(token, server.admin_token)
|
||||
|
||||
|
||||
def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
    """Stop and restart the image worker pool. Admin-token protected.

    Returns a 401 with an empty JSON body when the admin token is
    missing or wrong; otherwise joins the existing workers, starts a
    fresh pool, and returns the pool status as JSON.
    """
    if not check_admin(server):
        return make_response(jsonify({})), 401

    logger.info("restarting worker pool")
    pool.join()
    pool.start()
    logger.info("restarted worker pool")

    # A Flask view must return a response object; the original fell
    # through and returned None, which Flask turns into a 500 error
    # even though the restart itself succeeded.
    return jsonify(pool.status())
|
||||
|
||||
|
||||
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
    """Report the current status of the worker pool as a JSON response."""
    status = pool.status()
    return jsonify(status)
|
||||
|
||||
|
||||
def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
    """Attach the admin endpoints to the Flask app.

    Registers the worker-restart and worker-status routes, binding the
    server context and pool into each handler, and returns the list of
    registered handlers.
    """
    restart_route = app.route("/api/restart", methods=["POST"])(
        wrap_route(restart_workers, server, pool=pool)
    )
    status_route = app.route("/api/status")(
        wrap_route(worker_status, server, pool=pool)
    )
    return [restart_route, status_route]
|
|
@ -523,10 +523,6 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
|||
)
|
||||
|
||||
|
||||
def status(server: ServerContext, pool: DevicePoolExecutor):
|
||||
return jsonify(pool.status())
|
||||
|
||||
|
||||
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
||||
return [
|
||||
app.route("/api")(wrap_route(introspect, server, app=app)),
|
||||
|
@ -560,5 +556,4 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
|
|||
wrap_route(cancel, server, pool=pool)
|
||||
),
|
||||
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
|
||||
app.route("/api/status")(wrap_route(status, server, pool=pool)),
|
||||
]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from logging import getLogger
|
||||
from os import environ, path
|
||||
from secrets import token_urlsafe
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
@ -33,6 +34,7 @@ class ServerContext:
|
|||
extra_models: Optional[List[str]] = None,
|
||||
job_limit: int = DEFAULT_JOB_LIMIT,
|
||||
memory_limit: Optional[int] = None,
|
||||
admin_token: Optional[str] = None,
|
||||
) -> None:
|
||||
self.bundle_path = bundle_path
|
||||
self.model_path = model_path
|
||||
|
@ -50,6 +52,7 @@ class ServerContext:
|
|||
self.extra_models = extra_models or []
|
||||
self.job_limit = job_limit
|
||||
self.memory_limit = memory_limit
|
||||
self.admin_token = admin_token or token_urlsafe()
|
||||
|
||||
self.cache = ModelCache(self.cache_limit)
|
||||
|
||||
|
|
|
@ -334,13 +334,29 @@ export interface ApiClient {
|
|||
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Check whether some pipeline's output is ready yet.
|
||||
* Check whether job has finished and its output is ready.
|
||||
*/
|
||||
ready(key: string): Promise<ReadyResponse>;
|
||||
|
||||
/**
|
||||
* Cancel an existing job.
|
||||
*/
|
||||
cancel(key: string): Promise<boolean>;
|
||||
|
||||
/**
|
||||
* Retry a previous job using the same parameters.
|
||||
*/
|
||||
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
|
||||
|
||||
/**
|
||||
* Restart the image job workers.
|
||||
*/
|
||||
restart(token: string): Promise<boolean>;
|
||||
|
||||
/**
|
||||
* Check the status of the image job workers.
|
||||
*/
|
||||
status(token: string): Promise<Array<unknown>>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -752,7 +768,23 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
|||
default:
|
||||
throw new InvalidArgumentError('unknown request type');
|
||||
}
|
||||
}
|
||||
},
|
||||
async restart(token: string): Promise<boolean> {
|
||||
const path = makeApiUrl(root, 'restart');
|
||||
path.searchParams.append('token', token);
|
||||
|
||||
const res = await f(path, {
|
||||
method: 'POST',
|
||||
});
|
||||
return res.status === STATUS_SUCCESS;
|
||||
},
|
||||
async status(token: string): Promise<Array<unknown>> {
|
||||
const path = makeApiUrl(root, 'status');
|
||||
path.searchParams.append('token', token);
|
||||
|
||||
const res = await f(path);
|
||||
return res.json();
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import { mustExist } from '@apextoaster/js-utils';
|
||||
import { Stack } from '@mui/material';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import { Button, Stack } from '@mui/material';
|
||||
import { useMutation, useQuery } from '@tanstack/react-query';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
@ -21,6 +21,9 @@ export function ModelControl() {
|
|||
const setModel = useStore(state, (s) => s.setModel);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const token = '';
|
||||
|
||||
const restart = useMutation(['restart'], async () => client.restart(token));
|
||||
const models = useQuery(['models'], async () => client.models(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
@ -159,6 +162,7 @@ export function ModelControl() {
|
|||
addToken('lora', name);
|
||||
}}
|
||||
/>
|
||||
<Button onClick={() => restart.mutate()}>Restart</Button>
|
||||
</Stack>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue