1
0
Fork 0

feat: add admin endpoint to restart image workers (#207)

This commit is contained in:
Sean Sube 2023-04-20 07:36:31 -05:00
parent e12f3c2801
commit df0e7dc57e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 102 additions and 29 deletions

View File

@@ -23,7 +23,7 @@ if __name__ == '__main__':
# create the server and load the config
from onnx_web.main import main
app, pool = main()
server, app, pool = main()
# launch the image workers
print("starting image workers")
@@ -39,7 +39,7 @@ if __name__ == '__main__':
# launch the user's web browser
print("opening web browser")
url = "http://127.0.0.1:5000"
webbrowser.open_new_tab(f"{url}?api={url}")
webbrowser.open_new_tab(f"{url}?api={url}&token={server.admin_token}")
# wait for enter and exit
input("press enter to quit")

View File

@@ -12,10 +12,9 @@ from onnx import load_model, save_model
from transformers import CLIPTokenizer
from yaml import safe_load
from onnx_web.convert.diffusion.control import convert_diffusion_control
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.lora import blend_loras
from .diffusion.original import convert_diffusion_original

View File

@@ -11,6 +11,7 @@ from huggingface_hub.utils.tqdm import disable_progress_bars
from setproctitle import setproctitle
from torch.multiprocessing import set_start_method
from .server.admin import register_admin_routes
from .server.api import register_api_routes
from .server.context import ServerContext
from .server.hacks import apply_patches
@@ -37,55 +38,58 @@ def main():
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
context = ServerContext.from_environ()
apply_patches(context)
check_paths(context)
load_extras(context)
load_models(context)
load_params(context)
load_platforms(context)
# launch server, read env and list paths
server = ServerContext.from_environ()
apply_patches(server)
check_paths(server)
load_extras(server)
load_models(server)
load_params(server)
load_platforms(server)
if is_debug():
gc.set_debug(gc.DEBUG_STATS)
if not context.show_progress:
if not server.show_progress:
disable_progress_bar()
disable_progress_bars()
# create workers
# any is a fake device and should not be in the pool
pool = DevicePoolExecutor(
context, [p for p in get_available_platforms() if p.device != "any"]
server, [p for p in get_available_platforms() if p.device != "any"]
)
# create server
app = Flask(__name__)
CORS(app, origins=context.cors_origin)
CORS(app, origins=server.cors_origin)
# register routes
register_static_routes(app, context, pool)
register_api_routes(app, context, pool)
register_static_routes(app, server, pool)
register_api_routes(app, server, pool)
register_admin_routes(app, server, pool)
return app, pool
return server, app, pool
def run():
app, pool = main()
_server, app, pool = main()
pool.start()
def quit(p: DevicePoolExecutor):
logger.info("shutting down workers")
p.join()
# TODO: print admin token
atexit.register(partial(quit, pool))
return app
if __name__ == "__main__":
app, pool = main()
server, app, pool = main()
logger.info("starting image workers")
pool.start()
logger.info("starting API server")
logger.info("starting API server with admin token: %s", server.admin_token)
app.run("0.0.0.0", 5000, debug=is_debug())
logger.info("shutting down workers")
pool.join()

View File

@@ -0,0 +1,36 @@
from logging import getLogger
from flask import Flask, jsonify, make_response, request
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .utils import wrap_route
logger = getLogger(__name__)
def check_admin(server: ServerContext):
    """Return True when the current request carries the server's admin token.

    The token is read from the `token` query parameter of the active Flask
    request and compared against the one generated for this server instance.
    """
    supplied = request.args.get("token", None)
    return supplied == server.admin_token
def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
    """Drain and relaunch the image worker pool.

    Requires the admin token as the `token` query parameter; responds with
    an empty JSON body and HTTP 401 when the token does not match.
    """
    if not check_admin(server):
        return make_response(jsonify({})), 401

    logger.info("restarting worker pool")
    pool.join()
    pool.start()
    logger.info("restarted worker pool")

    # Return the status of the freshly started workers. Without an explicit
    # return the view yields None, which Flask rejects with a server error.
    return jsonify(pool.status())
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
    """Report the worker pool's current status as a JSON response."""
    current = pool.status()
    return jsonify(current)
def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
    """Attach the admin endpoints to the Flask app.

    Returns the list of registered view functions, matching the convention
    used by the other route registration helpers.
    """
    restart_view = app.route("/api/restart", methods=["POST"])(
        wrap_route(restart_workers, server, pool=pool)
    )
    status_view = app.route("/api/status")(
        wrap_route(worker_status, server, pool=pool)
    )
    return [restart_view, status_view]

View File

@@ -523,10 +523,6 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
)
def status(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(pool.status())
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
return [
app.route("/api")(wrap_route(introspect, server, app=app)),
@@ -560,5 +556,4 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
wrap_route(cancel, server, pool=pool)
),
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
app.route("/api/status")(wrap_route(status, server, pool=pool)),
]

View File

@@ -1,5 +1,6 @@
from logging import getLogger
from os import environ, path
from secrets import token_urlsafe
from typing import List, Optional
import torch
@@ -33,6 +34,7 @@ class ServerContext:
extra_models: Optional[List[str]] = None,
job_limit: int = DEFAULT_JOB_LIMIT,
memory_limit: Optional[int] = None,
admin_token: Optional[str] = None,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
@@ -50,6 +52,7 @@ class ServerContext:
self.extra_models = extra_models or []
self.job_limit = job_limit
self.memory_limit = memory_limit
self.admin_token = admin_token or token_urlsafe()
self.cache = ModelCache(self.cache_limit)

View File

@@ -334,13 +334,29 @@ export interface ApiClient {
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/**
* Check whether some pipeline's output is ready yet.
* Check whether job has finished and its output is ready.
*/
ready(key: string): Promise<ReadyResponse>;
/**
* Cancel an existing job.
*/
cancel(key: string): Promise<boolean>;
/**
* Retry a previous job using the same parameters.
*/
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
/**
* Restart the image job workers.
*/
restart(token: string): Promise<boolean>;
/**
* Check the status of the image job workers.
*/
status(token: string): Promise<Array<unknown>>;
}
/**
@@ -752,7 +768,23 @@ export function makeClient(root: string, f = fetch): ApiClient {
default:
throw new InvalidArgumentError('unknown request type');
}
}
},
// POST /api/restart with the admin token as a query parameter.
// Resolves true when the server responds with the success status code.
async restart(token: string): Promise<boolean> {
const path = makeApiUrl(root, 'restart');
path.searchParams.append('token', token);
const res = await f(path, {
method: 'POST',
});
return res.status === STATUS_SUCCESS;
},
// GET /api/status with the admin token as a query parameter and
// return the parsed JSON body (the worker pool status report).
async status(token: string): Promise<Array<unknown>> {
const path = makeApiUrl(root, 'status');
path.searchParams.append('token', token);
const res = await f(path);
return res.json();
},
};
}

View File

@@ -1,6 +1,6 @@
import { mustExist } from '@apextoaster/js-utils';
import { Stack } from '@mui/material';
import { useQuery } from '@tanstack/react-query';
import { Button, Stack } from '@mui/material';
import { useMutation, useQuery } from '@tanstack/react-query';
import * as React from 'react';
import { useContext } from 'react';
import { useTranslation } from 'react-i18next';
@@ -21,6 +21,9 @@ export function ModelControl() {
const setModel = useStore(state, (s) => s.setModel);
const { t } = useTranslation();
const token = '';
const restart = useMutation(['restart'], async () => client.restart(token));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
@@ -159,6 +162,7 @@ export function ModelControl() {
addToken('lora', name);
}}
/>
<Button onClick={() => restart.mutate()}>Restart</Button>
</Stack>
</Stack>;
}