send token with admin requests, return worker status after restarting
This commit is contained in:
parent
df0e7dc57e
commit
30c96be24f
|
@ -73,14 +73,14 @@ def main():
|
|||
|
||||
|
||||
def run():
|
||||
_server, app, pool = main()
|
||||
server, app, pool = main()
|
||||
pool.start()
|
||||
|
||||
def quit(p: DevicePoolExecutor):
|
||||
logger.info("shutting down workers")
|
||||
p.join()
|
||||
|
||||
# TODO: print admin token
|
||||
logger.info("starting API server with admin token: %s", server.admin_token)
|
||||
atexit.register(partial(quit, pool))
|
||||
return app
|
||||
|
||||
|
|
|
@ -18,10 +18,11 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
|
|||
return make_response(jsonify({})), 401
|
||||
|
||||
logger.info("restarting worker pool")
|
||||
pool.join()
|
||||
pool.start()
|
||||
pool.recycle(recycle_all=True)
|
||||
logger.info("restarted worker pool")
|
||||
|
||||
return jsonify(pool.status())
|
||||
|
||||
|
||||
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
|
||||
return jsonify(pool.status())
|
||||
|
|
|
@ -21,6 +21,7 @@ class WorkerContext:
|
|||
progress: "Queue[ProgressCommand]"
|
||||
last_progress: Optional[ProgressCommand]
|
||||
idle: "Value[bool]"
|
||||
timeout: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -42,6 +43,7 @@ class WorkerContext:
|
|||
self.active_pid = active_pid
|
||||
self.last_progress = None
|
||||
self.idle = idle
|
||||
self.timeout = 1.0
|
||||
|
||||
def start(self, job: str) -> None:
|
||||
self.job = job
|
||||
|
|
|
@ -27,6 +27,10 @@ class DevicePoolExecutor:
|
|||
recycle_interval: float
|
||||
|
||||
leaking: List[Tuple[str, Process, WorkerContext]]
|
||||
|
||||
worker_cancel: Dict[str, "Value[bool]"]
|
||||
worker_idle: Dict[str, "Value[bool]"]
|
||||
|
||||
context: Dict[str, WorkerContext] # Device -> Context
|
||||
current: Dict[str, "Value[int]"] # Device -> pid
|
||||
pending: Dict[str, "Queue[JobCommand]"]
|
||||
|
@ -51,7 +55,7 @@ class DevicePoolExecutor:
|
|||
server: ServerContext,
|
||||
devices: List[DeviceParams],
|
||||
max_pending_per_worker: int = 100,
|
||||
join_timeout: float = 1.0,
|
||||
join_timeout: float = 5.0,
|
||||
recycle_interval: float = 10,
|
||||
progress_interval: float = 1.0,
|
||||
):
|
||||
|
@ -76,6 +80,8 @@ class DevicePoolExecutor:
|
|||
self.pending_jobs = []
|
||||
self.running_jobs = {}
|
||||
self.total_jobs = {}
|
||||
self.worker_cancel = {}
|
||||
self.worker_idle = {}
|
||||
|
||||
self.logs = Queue(self.max_pending_per_worker)
|
||||
self.rlock = Lock()
|
||||
|
@ -105,16 +111,19 @@ class DevicePoolExecutor:
|
|||
current = Value("L", 0)
|
||||
self.current[name] = current
|
||||
|
||||
self.worker_cancel[name] = Value("B", False)
|
||||
self.worker_idle[name] = Value("B", False)
|
||||
|
||||
# create a new context and worker
|
||||
context = WorkerContext(
|
||||
name,
|
||||
device,
|
||||
cancel=Value("B", False),
|
||||
cancel=self.worker_cancel[name],
|
||||
progress=self.progress[name],
|
||||
logs=self.logs,
|
||||
pending=self.pending[name],
|
||||
active_pid=current,
|
||||
idle=Value("B", False),
|
||||
idle=self.worker_idle[name],
|
||||
)
|
||||
self.context[name] = context
|
||||
|
||||
|
@ -283,12 +292,12 @@ class DevicePoolExecutor:
|
|||
def join_leaking(self):
|
||||
if len(self.leaking) > 0:
|
||||
for device, worker, context in self.leaking:
|
||||
logger.warning(
|
||||
logger.debug(
|
||||
"shutting down leaking worker %s for device %s", worker.pid, device
|
||||
)
|
||||
worker.join(self.join_timeout)
|
||||
if worker.is_alive():
|
||||
logger.error(
|
||||
logger.warning(
|
||||
"leaking worker %s for device %s could not be shut down",
|
||||
worker.pid,
|
||||
device,
|
||||
|
@ -310,7 +319,7 @@ class DevicePoolExecutor:
|
|||
|
||||
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
|
||||
|
||||
def recycle(self):
|
||||
def recycle(self, recycle_all=False):
|
||||
logger.debug("recycling worker pool")
|
||||
|
||||
with self.rlock:
|
||||
|
@ -323,14 +332,14 @@ class DevicePoolExecutor:
|
|||
if not worker.is_alive():
|
||||
logger.warning("worker for device %s has died", device)
|
||||
needs_restart.append(device)
|
||||
elif jobs > self.max_jobs_per_worker:
|
||||
elif recycle_all or jobs > self.max_jobs_per_worker:
|
||||
logger.info(
|
||||
"shutting down worker for device %s after %s jobs", device, jobs
|
||||
)
|
||||
worker.join(self.join_timeout)
|
||||
if worker.is_alive():
|
||||
logger.warning(
|
||||
"worker %s for device %s could not be recycled in time",
|
||||
"worker %s for device %s could not be shut down in time",
|
||||
worker.pid,
|
||||
device,
|
||||
)
|
||||
|
|
|
@ -48,7 +48,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
|
|||
exit(EXIT_REPLACED)
|
||||
|
||||
# wait briefly for the next job
|
||||
job = worker.pending.get(timeout=1.0)
|
||||
job = worker.pending.get(timeout=worker.timeout)
|
||||
logger.info("worker %s got job: %s", worker.device.device, job.name)
|
||||
|
||||
# clear flags and save the job name
|
||||
|
|
|
@ -62,4 +62,10 @@ export const LOCAL_CLIENT = {
|
|||
async strings() {
|
||||
return {};
|
||||
},
|
||||
async restart(token) {
|
||||
throw new NoServerError();
|
||||
},
|
||||
async status(token) {
|
||||
throw new NoServerError();
|
||||
}
|
||||
} as ApiClient;
|
||||
|
|
|
@ -21,9 +21,12 @@ export function ModelControl() {
|
|||
const setModel = useStore(state, (s) => s.setModel);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const token = '';
|
||||
// get token from query string
|
||||
const query = new URLSearchParams(window.location.search);
|
||||
const token = query.get('token');
|
||||
const [hash, _setHash] = useHash();
|
||||
|
||||
const restart = useMutation(['restart'], async () => client.restart(token));
|
||||
const restart = useMutation(['restart'], async () => client.restart(mustExist(token)));
|
||||
const models = useQuery(['models'], async () => client.models(), {
|
||||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
@ -34,8 +37,6 @@ export function ModelControl() {
|
|||
staleTime: STALE_TIME,
|
||||
});
|
||||
|
||||
const [hash, _setHash] = useHash();
|
||||
|
||||
function addToken(type: string, name: string, weight = 1.0) {
|
||||
const tab = getTab(hash);
|
||||
const current = state.getState();
|
||||
|
|
Loading…
Reference in New Issue