feat: add admin endpoint to restart image workers (#207)
This commit is contained in:
parent
e12f3c2801
commit
df0e7dc57e
|
@ -23,7 +23,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# create the server and load the config
|
# create the server and load the config
|
||||||
from onnx_web.main import main
|
from onnx_web.main import main
|
||||||
app, pool = main()
|
server, app, pool = main()
|
||||||
|
|
||||||
# launch the image workers
|
# launch the image workers
|
||||||
print("starting image workers")
|
print("starting image workers")
|
||||||
|
@ -39,7 +39,7 @@ if __name__ == '__main__':
|
||||||
# launch the user's web browser
|
# launch the user's web browser
|
||||||
print("opening web browser")
|
print("opening web browser")
|
||||||
url = "http://127.0.0.1:5000"
|
url = "http://127.0.0.1:5000"
|
||||||
webbrowser.open_new_tab(f"{url}?api={url}")
|
webbrowser.open_new_tab(f"{url}?api={url}&token={server.admin_token}")
|
||||||
|
|
||||||
# wait for enter and exit
|
# wait for enter and exit
|
||||||
input("press enter to quit")
|
input("press enter to quit")
|
||||||
|
|
|
@ -12,10 +12,9 @@ from onnx import load_model, save_model
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
from onnx_web.convert.diffusion.control import convert_diffusion_control
|
|
||||||
|
|
||||||
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
from ..constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||||
from .correction.gfpgan import convert_correction_gfpgan
|
from .correction.gfpgan import convert_correction_gfpgan
|
||||||
|
from .diffusion.control import convert_diffusion_control
|
||||||
from .diffusion.diffusers import convert_diffusion_diffusers
|
from .diffusion.diffusers import convert_diffusion_diffusers
|
||||||
from .diffusion.lora import blend_loras
|
from .diffusion.lora import blend_loras
|
||||||
from .diffusion.original import convert_diffusion_original
|
from .diffusion.original import convert_diffusion_original
|
||||||
|
|
|
@ -11,6 +11,7 @@ from huggingface_hub.utils.tqdm import disable_progress_bars
|
||||||
from setproctitle import setproctitle
|
from setproctitle import setproctitle
|
||||||
from torch.multiprocessing import set_start_method
|
from torch.multiprocessing import set_start_method
|
||||||
|
|
||||||
|
from .server.admin import register_admin_routes
|
||||||
from .server.api import register_api_routes
|
from .server.api import register_api_routes
|
||||||
from .server.context import ServerContext
|
from .server.context import ServerContext
|
||||||
from .server.hacks import apply_patches
|
from .server.hacks import apply_patches
|
||||||
|
@ -37,55 +38,58 @@ def main():
|
||||||
mimetypes.add_type("application/javascript", ".js")
|
mimetypes.add_type("application/javascript", ".js")
|
||||||
mimetypes.add_type("text/css", ".css")
|
mimetypes.add_type("text/css", ".css")
|
||||||
|
|
||||||
context = ServerContext.from_environ()
|
# launch server, read env and list paths
|
||||||
apply_patches(context)
|
server = ServerContext.from_environ()
|
||||||
check_paths(context)
|
apply_patches(server)
|
||||||
load_extras(context)
|
check_paths(server)
|
||||||
load_models(context)
|
load_extras(server)
|
||||||
load_params(context)
|
load_models(server)
|
||||||
load_platforms(context)
|
load_params(server)
|
||||||
|
load_platforms(server)
|
||||||
|
|
||||||
if is_debug():
|
if is_debug():
|
||||||
gc.set_debug(gc.DEBUG_STATS)
|
gc.set_debug(gc.DEBUG_STATS)
|
||||||
|
|
||||||
if not context.show_progress:
|
if not server.show_progress:
|
||||||
disable_progress_bar()
|
disable_progress_bar()
|
||||||
disable_progress_bars()
|
disable_progress_bars()
|
||||||
|
|
||||||
# create workers
|
# create workers
|
||||||
# any is a fake device and should not be in the pool
|
# any is a fake device and should not be in the pool
|
||||||
pool = DevicePoolExecutor(
|
pool = DevicePoolExecutor(
|
||||||
context, [p for p in get_available_platforms() if p.device != "any"]
|
server, [p for p in get_available_platforms() if p.device != "any"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# create server
|
# create server
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
CORS(app, origins=context.cors_origin)
|
CORS(app, origins=server.cors_origin)
|
||||||
|
|
||||||
# register routes
|
# register routes
|
||||||
register_static_routes(app, context, pool)
|
register_static_routes(app, server, pool)
|
||||||
register_api_routes(app, context, pool)
|
register_api_routes(app, server, pool)
|
||||||
|
register_admin_routes(app, server, pool)
|
||||||
|
|
||||||
return app, pool
|
return server, app, pool
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
app, pool = main()
|
_server, app, pool = main()
|
||||||
pool.start()
|
pool.start()
|
||||||
|
|
||||||
def quit(p: DevicePoolExecutor):
|
def quit(p: DevicePoolExecutor):
|
||||||
logger.info("shutting down workers")
|
logger.info("shutting down workers")
|
||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
|
# TODO: print admin token
|
||||||
atexit.register(partial(quit, pool))
|
atexit.register(partial(quit, pool))
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app, pool = main()
|
server, app, pool = main()
|
||||||
logger.info("starting image workers")
|
logger.info("starting image workers")
|
||||||
pool.start()
|
pool.start()
|
||||||
logger.info("starting API server")
|
logger.info("starting API server with admin token: %s", server.admin_token)
|
||||||
app.run("0.0.0.0", 5000, debug=is_debug())
|
app.run("0.0.0.0", 5000, debug=is_debug())
|
||||||
logger.info("shutting down workers")
|
logger.info("shutting down workers")
|
||||||
pool.join()
|
pool.join()
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from flask import Flask, jsonify, make_response, request
|
||||||
|
|
||||||
|
from ..worker.pool import DevicePoolExecutor
|
||||||
|
from .context import ServerContext
|
||||||
|
from .utils import wrap_route
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def check_admin(server: ServerContext):
    """Check whether the current request carries a valid admin token.

    Reads the ``token`` query parameter and compares it against the
    server's admin token. Uses a constant-time comparison so the token
    cannot be recovered through response-timing differences.
    """
    # function-scope import keeps the module's top-level imports unchanged
    from hmac import compare_digest

    token = request.args.get("token", None)
    if token is None:
        return False

    # compare as bytes: compare_digest raises TypeError on non-ASCII str
    # input, which an attacker could otherwise trigger deliberately
    return compare_digest(token.encode("utf-8"), server.admin_token.encode("utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
    """Admin endpoint: stop and restart the image worker pool.

    Requires a valid admin token (see ``check_admin``); returns 401 with
    an empty JSON body when the token is missing or wrong. On success,
    returns the status of the freshly restarted pool as JSON.
    """
    if not check_admin(server):
        return make_response(jsonify({})), 401

    logger.info("restarting worker pool")
    # join() blocks until the current workers have shut down cleanly
    pool.join()
    pool.start()
    logger.info("restarted worker pool")

    # a Flask view must return a response; returning None would produce a
    # 500 error. Report the new pool status, matching worker_status.
    return jsonify(pool.status())
|
||||||
|
|
||||||
|
|
||||||
|
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
    """Admin endpoint: report the current worker pool status as JSON."""
    status = pool.status()
    return jsonify(status)
|
||||||
|
|
||||||
|
|
||||||
|
def register_admin_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
    """Attach the admin endpoints to the Flask app.

    Registers the worker-restart and worker-status routes, each wrapped
    with the server context and pool, and returns the registered views.
    """
    restart_view = app.route("/api/restart", methods=["POST"])(
        wrap_route(restart_workers, server, pool=pool)
    )
    status_view = app.route("/api/status")(
        wrap_route(worker_status, server, pool=pool)
    )
    return [restart_view, status_view]
|
|
@ -523,10 +523,6 @@ def ready(server: ServerContext, pool: DevicePoolExecutor):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def status(server: ServerContext, pool: DevicePoolExecutor):
|
|
||||||
return jsonify(pool.status())
|
|
||||||
|
|
||||||
|
|
||||||
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecutor):
|
||||||
return [
|
return [
|
||||||
app.route("/api")(wrap_route(introspect, server, app=app)),
|
app.route("/api")(wrap_route(introspect, server, app=app)),
|
||||||
|
@ -560,5 +556,4 @@ def register_api_routes(app: Flask, server: ServerContext, pool: DevicePoolExecu
|
||||||
wrap_route(cancel, server, pool=pool)
|
wrap_route(cancel, server, pool=pool)
|
||||||
),
|
),
|
||||||
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
|
app.route("/api/ready")(wrap_route(ready, server, pool=pool)),
|
||||||
app.route("/api/status")(wrap_route(status, server, pool=pool)),
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from os import environ, path
|
from os import environ, path
|
||||||
|
from secrets import token_urlsafe
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -33,6 +34,7 @@ class ServerContext:
|
||||||
extra_models: Optional[List[str]] = None,
|
extra_models: Optional[List[str]] = None,
|
||||||
job_limit: int = DEFAULT_JOB_LIMIT,
|
job_limit: int = DEFAULT_JOB_LIMIT,
|
||||||
memory_limit: Optional[int] = None,
|
memory_limit: Optional[int] = None,
|
||||||
|
admin_token: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bundle_path = bundle_path
|
self.bundle_path = bundle_path
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
|
@ -50,6 +52,7 @@ class ServerContext:
|
||||||
self.extra_models = extra_models or []
|
self.extra_models = extra_models or []
|
||||||
self.job_limit = job_limit
|
self.job_limit = job_limit
|
||||||
self.memory_limit = memory_limit
|
self.memory_limit = memory_limit
|
||||||
|
self.admin_token = admin_token or token_urlsafe()
|
||||||
|
|
||||||
self.cache = ModelCache(self.cache_limit)
|
self.cache = ModelCache(self.cache_limit)
|
||||||
|
|
||||||
|
|
|
@ -334,13 +334,29 @@ export interface ApiClient {
|
||||||
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
|
blend(model: ModelParams, params: BlendParams, upscale?: UpscaleParams): Promise<ImageResponseWithRetry>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Check whether some pipeline's output is ready yet.
|
* Check whether job has finished and its output is ready.
|
||||||
*/
|
*/
|
||||||
ready(key: string): Promise<ReadyResponse>;
|
ready(key: string): Promise<ReadyResponse>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cancel an existing job.
|
||||||
|
*/
|
||||||
cancel(key: string): Promise<boolean>;
|
cancel(key: string): Promise<boolean>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retry a previous job using the same parameters.
|
||||||
|
*/
|
||||||
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
|
retry(params: RetryParams): Promise<ImageResponseWithRetry>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Restart the image job workers.
|
||||||
|
*/
|
||||||
|
restart(token: string): Promise<boolean>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check the status of the image job workers.
|
||||||
|
*/
|
||||||
|
status(token: string): Promise<Array<unknown>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -752,7 +768,23 @@ export function makeClient(root: string, f = fetch): ApiClient {
|
||||||
default:
|
default:
|
||||||
throw new InvalidArgumentError('unknown request type');
|
throw new InvalidArgumentError('unknown request type');
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
async restart(token: string): Promise<boolean> {
|
||||||
|
const path = makeApiUrl(root, 'restart');
|
||||||
|
path.searchParams.append('token', token);
|
||||||
|
|
||||||
|
const res = await f(path, {
|
||||||
|
method: 'POST',
|
||||||
|
});
|
||||||
|
return res.status === STATUS_SUCCESS;
|
||||||
|
},
|
||||||
|
async status(token: string): Promise<Array<unknown>> {
|
||||||
|
const path = makeApiUrl(root, 'status');
|
||||||
|
path.searchParams.append('token', token);
|
||||||
|
|
||||||
|
const res = await f(path);
|
||||||
|
return res.json();
|
||||||
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import { mustExist } from '@apextoaster/js-utils';
|
import { mustExist } from '@apextoaster/js-utils';
|
||||||
import { Stack } from '@mui/material';
|
import { Button, Stack } from '@mui/material';
|
||||||
import { useQuery } from '@tanstack/react-query';
|
import { useMutation, useQuery } from '@tanstack/react-query';
|
||||||
import * as React from 'react';
|
import * as React from 'react';
|
||||||
import { useContext } from 'react';
|
import { useContext } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
|
@ -21,6 +21,9 @@ export function ModelControl() {
|
||||||
const setModel = useStore(state, (s) => s.setModel);
|
const setModel = useStore(state, (s) => s.setModel);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
const token = '';
|
||||||
|
|
||||||
|
const restart = useMutation(['restart'], async () => client.restart(token));
|
||||||
const models = useQuery(['models'], async () => client.models(), {
|
const models = useQuery(['models'], async () => client.models(), {
|
||||||
staleTime: STALE_TIME,
|
staleTime: STALE_TIME,
|
||||||
});
|
});
|
||||||
|
@ -159,6 +162,7 @@ export function ModelControl() {
|
||||||
addToken('lora', name);
|
addToken('lora', name);
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
|
<Button onClick={() => restart.mutate()}>Restart</Button>
|
||||||
</Stack>
|
</Stack>
|
||||||
</Stack>;
|
</Stack>;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue