1
0
Fork 0

feat: add admin endpoint to restart image workers (#207)

This commit is contained in:
Sean Sube 2023-04-20 07:36:31 -05:00
parent e12f3c2801
commit df0e7dc57e
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
8 changed files with 102 additions and 29 deletions

View File

@ -23,7 +23,7 @@ if __name__ == '__main__':
# create the server and load the config # create the server and load the config
from onnx_web.main import main from onnx_web.main import main
app, pool = main() server, app, pool = main()
# launch the image workers # launch the image workers
print("starting image workers") print("starting image workers")
@ -39,7 +39,7 @@ if __name__ == '__main__':
# launch the user's web browser # launch the user's web browser
print("opening web browser") print("opening web browser")
url = "http://127.0.0.1:5000" url = "http://127.0.0.1:5000"
webbrowser.open_new_tab(f"{url}?api={url}") webbrowser.open_new_tab(f"{url}?api={url}&token={server.admin_token}")
# wait for enter and exit # wait for enter and exit
input("press enter to quit") input("press enter to quit")

View File

@ -12,10 +12,9 @@ from onnx import load_model, save_model
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from yaml import safe_load from yaml import safe_load
from onnx_web.convert.diffusion.control import convert_diffusion_control
from ..constants import ONNX_MODEL, ONNX_WEIGHTS from ..constants import ONNX_MODEL, ONNX_WEIGHTS
from .correction.gfpgan import convert_correction_gfpgan from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control
from .diffusion.diffusers import convert_diffusion_diffusers from .diffusion.diffusers import convert_diffusion_diffusers
from .diffusion.lora import blend_loras from .diffusion.lora import blend_loras
from .diffusion.original import convert_diffusion_original from .diffusion.original import convert_diffusion_original

View File

@ -11,6 +11,7 @@ from huggingface_hub.utils.tqdm import disable_progress_bars
from setproctitle import setproctitle from setproctitle import setproctitle
from torch.multiprocessing import set_start_method from torch.multiprocessing import set_start_method
from .server.admin import register_admin_routes
from .server.api import register_api_routes from .server.api import register_api_routes
from .server.context import ServerContext from .server.context import ServerContext
from .server.hacks import apply_patches from .server.hacks import apply_patches
@ -37,55 +38,58 @@ def main():
mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css") mimetypes.add_type("text/css", ".css")
context = ServerContext.from_environ() # launch server, read env and list paths
apply_patches(context) server = ServerContext.from_environ()
check_paths(context) apply_patches(server)
load_extras(context) check_paths(server)
load_models(context) load_extras(server)
load_params(context) load_models(server)
load_platforms(context) load_params(server)
load_platforms(server)
if is_debug(): if is_debug():
gc.set_debug(gc.DEBUG_STATS) gc.set_debug(gc.DEBUG_STATS)
if not context.show_progress: if not server.show_progress:
disable_progress_bar() disable_progress_bar()
disable_progress_bars() disable_progress_bars()
# create workers # create workers
# any is a fake device and should not be in the pool # any is a fake device and should not be in the pool
pool = DevicePoolExecutor( pool = DevicePoolExecutor(
context, [p for p in get_available_platforms() if p.device != "any"] server, [p for p in get_available_platforms() if p.device != "any"]
) )
# create server # create server
app = Flask(__name__) app = Flask(__name__)
CORS(app, origins=context.cors_origin) CORS(app, origins=server.cors_origin)
# register routes # register routes
register_static_routes(app, context, pool) register_static_routes(app, server, pool)
register_api_routes(app, context, pool) register_api_routes(app, server, pool)
register_admin_routes(app, server, pool)
return app, pool return server, app, pool
def run(): def run():
app, pool = main() _server, app, pool = main()
pool.start() pool.start()
def quit(p: DevicePoolExecutor): def quit(p: DevicePoolExecutor):
logger.info("shutting down workers") logger.info("shutting down workers")
p.join() p.join()
# TODO: print admin token
atexit.register(partial(quit, pool)) atexit.register(partial(quit, pool))
return app return app
if __name__ == "__main__": if __name__ == "__main__":
app, pool = main() server, app, pool = main()
logger.info("starting image workers") logger.info("starting image workers")
pool.start() pool.start()
logger.info("starting API server") logger.info("starting API server with admin token: %s", server.admin_token)
app.run("0.0.0.0", 5000, debug=is_debug()) app.run("0.0.0.0", 5000, debug=is_debug())
logger.info("shutting down workers") logger.info("shutting down workers")
pool.join() pool.join()

View File

@ -0,0 +1,36 @@
from logging import getLogger
from flask import Flask, jsonify, make_response, request
from ..worker.pool import DevicePoolExecutor
from .context import ServerContext
from .utils import wrap_route
logger = getLogger(__name__)
def check_admin(server: ServerContext):
    """Return True when the request carries the server's admin token.

    The token is read from the ``?token=`` query parameter and compared
    against the secret generated for this server at startup.
    """
    provided = request.args.get("token", None)
    return provided == server.admin_token
def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
    """Drain and restart the image worker pool.

    Requires the admin token as a ``?token=`` query parameter; responds
    with an empty 401 body when the token is missing or wrong. On
    success, returns the status of the freshly restarted pool as JSON.
    """
    if not check_admin(server):
        # reject unauthenticated callers before touching the pool
        return make_response(jsonify({})), 401

    logger.info("restarting worker pool")
    pool.join()
    pool.start()
    logger.info("restarted worker pool")

    # bug fix: the original fell off the end and returned None on success,
    # which Flask rejects ("view function did not return a valid response").
    # Report the new pool state instead, matching worker_status.
    return jsonify(pool.status())
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
    """Report the current status of the worker pool as a JSON response."""
    status = pool.status()
    return jsonify(status)
def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
    """Attach the admin endpoints to the Flask app.

    Returns the list of registered route handlers, mirroring the other
    register_*_routes helpers in this package.
    """
    restart_route = app.route("/api/restart", methods=["POST"])(
        wrap_route(restart_workers, server, pool=pool)
    )
    status_route = app.route("/api/status")(
        wrap_route(worker_status, server, pool=pool)
    )
    return [restart_route, status_route]

View File

@ -523,10 +523,6 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
) )
def status(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(pool.status())
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor): def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
return [ return [
app.route("/api")(wrap_route(introspect, server, app=app)), app.route("/api")(wrap_route(introspect, server, app=app)),
@ -560,5 +556,4 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
wrap_route(cancel, server, pool=pool) wrap_route(cancel, server, pool=pool)
), ),
app.route("/api/ready")(wrap_route(ready, server, pool=pool)), app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
app.route("/api/status")(wrap_route(status, server, pool=pool)),
] ]

View File

@ -1,5 +1,6 @@
from logging import getLogger from logging import getLogger
from os import environ, path from os import environ, path
from secrets import token_urlsafe
from typing import List, Optional from typing import List, Optional
import torch import torch
@ -33,6 +34,7 @@ class ServerContext:
extra_models: Optional[List[str]] = None, extra_models: Optional[List[str]] = None,
job_limit: int = DEFAULT_JOB_LIMIT, job_limit: int = DEFAULT_JOB_LIMIT,
memory_limit: Optional[int] = None, memory_limit: Optional[int] = None,
admin_token: Optional[str] = None,
) -> None: ) -> None:
self.bundle_path = bundle_path self.bundle_path = bundle_path
self.model_path = model_path self.model_path = model_path
@ -50,6 +52,7 @@ class ServerContext:
self.extra_models = extra_models or [] self.extra_models = extra_models or []
self.job_limit = job_limit self.job_limit = job_limit
self.memory_limit = memory_limit self.memory_limit = memory_limit
self.admin_token = admin_token or token_urlsafe()
self.cache = ModelCache(self.cache_limit) self.cache = ModelCache(self.cache_limit)

View File

@ -334,13 +334,29 @@ export interface ApiClient {
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>; blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
/** /**
* Check whether some pipeline's output is ready yet. * Check whether job has finished and its output is ready.
*/ */
ready(key: string): Promise<ReadyResponse>; ready(key: string): Promise<ReadyResponse>;
/**
* Cancel an existing job.
*/
cancel(key: string): Promise<boolean>; cancel(key: string): Promise<boolean>;
/**
* Retry a previous job using the same parameters.
*/
retry(params: RetryParams): Promise<ImageResponseWithRetry>; retry(params: RetryParams): Promise<ImageResponseWithRetry>;
/**
* Restart the image job workers.
*/
restart(token: string): Promise<boolean>;
/**
* Check the status of the image job workers.
*/
status(token: string): Promise<Array<unknown>>;
} }
/** /**
@ -752,7 +768,23 @@ export function makeClient(root: string, f = fetch): ApiClient {
default: default:
throw new InvalidArgumentError('unknown request type'); throw new InvalidArgumentError('unknown request type');
} }
} },
/**
 * POST /api/restart with the admin token as a `token` query parameter.
 * Resolves to true when the server responds with the success status code.
 */
async restart(token: string): Promise<boolean> {
const path = makeApiUrl(root, 'restart');
path.searchParams.append('token', token);
const res = await f(path, {
method: 'POST',
});
return res.status === STATUS_SUCCESS;
},
/**
 * GET /api/status with the admin token as a `token` query parameter.
 * Resolves to the parsed JSON body.
 * NOTE(review): presumably the per-device worker status produced by the
 * server's pool.status() — confirm the payload shape against the backend.
 */
async status(token: string): Promise<Array<unknown>> {
const path = makeApiUrl(root, 'status');
path.searchParams.append('token', token);
const res = await f(path);
return res.json();
},
}; };
} }

View File

@ -1,6 +1,6 @@
import { mustExist } from '@apextoaster/js-utils'; import { mustExist } from '@apextoaster/js-utils';
import { Stack } from '@mui/material'; import { Button, Stack } from '@mui/material';
import { useQuery } from '@tanstack/react-query'; import { useMutation, useQuery } from '@tanstack/react-query';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react'; import { useContext } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -21,6 +21,9 @@ export function ModelControl() {
const setModel = useStore(state, (s) => s.setModel); const setModel = useStore(state, (s) => s.setModel);
const { t } = useTranslation(); const { t } = useTranslation();
const token = '';
const restart = useMutation(['restart'], async () => client.restart(token));
const models = useQuery(['models'], async () => client.models(), { const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
@ -159,6 +162,7 @@ export function ModelControl() {
addToken('lora', name); addToken('lora', name);
}} }}
/> />
<Button onClick={() => restart.mutate()}>Restart</Button>
</Stack> </Stack>
</Stack>; </Stack>;
} }