
send token with admin requests, return worker status after restarting

This commit is contained in:
Sean Sube 2023-04-20 17:36:29 -05:00
parent df0e7dc57e
commit 30c96be24f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 36 additions and 17 deletions

View File

@ -73,14 +73,14 @@ def main():
def run():
_server, app, pool = main()
server, app, pool = main()
pool.start()
def quit(p: DevicePoolExecutor):
logger.info("shutting down workers")
p.join()
# TODO: print admin token
logger.info("starting API server with admin token: %s", server.admin_token)
atexit.register(partial(quit, pool))
return app
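This hunk keeps the ServerContext that main() returns (previously discarded as _server) so the admin token can be logged when the API server starts, and it registers the pool shutdown hook with atexit. A minimal, runnable sketch of that atexit/partial cleanup pattern, using stand-in classes instead of the real server and pool:

from logging import basicConfig, getLogger
import atexit
from functools import partial

basicConfig(level="INFO")
logger = getLogger(__name__)

class FakePool:
    # stand-in for DevicePoolExecutor
    def start(self):
        logger.info("pool started")

    def join(self):
        logger.info("pool joined")

def quit(p: FakePool):
    # runs once at interpreter exit, mirroring the hook registered in run()
    logger.info("shutting down workers")
    p.join()

pool = FakePool()
pool.start()

# bind the pool as the argument so atexit can call quit(pool) at shutdown
atexit.register(partial(quit, pool))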

View File

@ -18,10 +18,11 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
return make_response(jsonify({})), 401
logger.info("restarting worker pool")
pool.join()
pool.start()
pool.recycle(recycle_all=True)
logger.info("restarted worker pool")
return jsonify(pool.status())
def worker_status(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(pool.status())
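The restart endpoint now recycles every worker in place via pool.recycle(recycle_all=True) and answers with the same pool.status() payload that worker_status serves, so an admin client sees the new worker state in the restart response itself. A rough client-side sketch using the requests library; the host, port, and route paths below are assumptions for illustration (they are not shown in this diff), and the token is the one logged at server startup:

import requests

API = "http://127.0.0.1:5000"          # assumed server address
TOKEN = "<admin token from the server log>"

# hypothetical admin routes; check the server's registered routes for the real paths
status = requests.get(f"{API}/api/worker/status", params={"token": TOKEN})
print(status.json())

restarted = requests.post(f"{API}/api/worker/restart", params={"token": TOKEN})
print(restarted.json())  # now carries the post-restart worker status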

View File

@ -21,6 +21,7 @@ class WorkerContext:
progress: "Queue[ProgressCommand]"
last_progress: Optional[ProgressCommand]
idle: "Value[bool]"
timeout: float
def __init__(
self,
@ -42,6 +43,7 @@ class WorkerContext:
self.active_pid = active_pid
self.last_progress = None
self.idle = idle
self.timeout = 1.0
def start(self, job: str) -> None:
self.job = job

View File

@ -27,6 +27,10 @@ class DevicePoolExecutor:
recycle_interval: float
leaking: List[Tuple[str, Process, WorkerContext]]
worker_cancel: Dict[str, "Value[bool]"]
worker_idle: Dict[str, "Value[bool]"]
context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"] # Device -> pid
pending: Dict[str, "Queue[JobCommand]"]
@ -51,7 +55,7 @@ class DevicePoolExecutor:
server: ServerContext,
devices: List[DeviceParams],
max_pending_per_worker: int = 100,
join_timeout: float = 1.0,
join_timeout: float = 5.0,
recycle_interval: float = 10,
progress_interval: float = 1.0,
):
@ -76,6 +80,8 @@ class DevicePoolExecutor:
self.pending_jobs = []
self.running_jobs = {}
self.total_jobs = {}
self.worker_cancel = {}
self.worker_idle = {}
self.logs = Queue(self.max_pending_per_worker)
self.rlock = Lock()
@ -105,16 +111,19 @@ class DevicePoolExecutor:
current = Value("L", 0)
self.current[name] = current
self.worker_cancel[name] = Value("B", False)
self.worker_idle[name] = Value("B", False)
# create a new context and worker
context = WorkerContext(
name,
device,
cancel=Value("B", False),
cancel=self.worker_cancel[name],
progress=self.progress[name],
logs=self.logs,
pending=self.pending[name],
active_pid=current,
idle=Value("B", False),
idle=self.worker_idle[name],
)
self.context[name] = context
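Instead of building throwaway Value("B", False) flags inside the WorkerContext call, the pool now stores them in its own worker_cancel and worker_idle dicts first and passes those same objects to the context, so the parent process keeps a handle to the exact flags each worker polls. A small sketch of why that sharing matters, assuming nothing beyond multiprocessing itself:

from multiprocessing import Process, Value
from time import sleep

def worker(cancel: "Value[bool]") -> None:
    # the child polls the shared flag; no queue or extra IPC is needed
    while not cancel.value:
        sleep(0.1)
    print("worker saw the cancel flag and is exiting")

if __name__ == "__main__":
    cancel = Value("B", False)   # the parent keeps this reference...
    child = Process(target=worker, args=(cancel,))
    child.start()

    sleep(0.5)
    cancel.value = True          # ...so it can flip the flag the child is watching
    child.join()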
@ -283,12 +292,12 @@ class DevicePoolExecutor:
def join_leaking(self):
if len(self.leaking) > 0:
for device, worker, context in self.leaking:
logger.warning(
logger.debug(
"shutting down leaking worker %s for device %s", worker.pid, device
)
worker.join(self.join_timeout)
if worker.is_alive():
logger.error(
logger.warning(
"leaking worker %s for device %s could not be shut down",
worker.pid,
device,
@ -310,7 +319,7 @@ class DevicePoolExecutor:
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
def recycle(self):
def recycle(self, recycle_all=False):
logger.debug("recycling worker pool")
with self.rlock:
@ -323,14 +332,14 @@ class DevicePoolExecutor:
if not worker.is_alive():
logger.warning("worker for device %s has died", device)
needs_restart.append(device)
elif jobs > self.max_jobs_per_worker:
elif recycle_all or jobs > self.max_jobs_per_worker:
logger.info(
"shutting down worker for device %s after %s jobs", device, jobs
)
worker.join(self.join_timeout)
if worker.is_alive():
logger.warning(
"worker %s for device %s could not be recycled in time",
"worker %s for device %s could not be shut down in time",
worker.pid,
device,
)
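recycle(recycle_all=True) lets the restart endpoint push every live worker through the same join-with-timeout path that normally only runs after max_jobs_per_worker jobs. A compact sketch of that join-then-check step, with the surrounding bookkeeping (needs_restart, the leaking list) left out:

from multiprocessing import Process

def shut_down_worker(worker: Process, join_timeout: float = 5.0) -> bool:
    # ask the worker to finish, then verify that it actually exited in time
    worker.join(join_timeout)
    return not worker.is_alive()

When this returns False, the pool logs the "could not be shut down in time" warning shown above instead of treating the worker as stopped.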

View File

@ -48,7 +48,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
exit(EXIT_REPLACED)
# wait briefly for the next job
job = worker.pending.get(timeout=1.0)
job = worker.pending.get(timeout=worker.timeout)
logger.info("worker %s got job: %s", worker.device.device, job.name)
# clear flags and save the job name
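The hard-coded one-second poll is replaced by the timeout field added to WorkerContext above, so the wait on the pending queue can be tuned per context. A sketch of the polling loop around that call, assuming only that pending is a multiprocessing queue whose get raises queue.Empty on timeout; the real loop also checks the cancel, idle, and active_pid state shown elsewhere in this diff:

from queue import Empty

def poll_jobs(worker) -> None:
    while True:
        try:
            # wait briefly for the next job, using the context's timeout
            job = worker.pending.get(timeout=worker.timeout)
        except Empty:
            # nothing arrived in time; loop and re-check the worker's flags
            continue

        print("got job:", job.name)
        break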

View File

@ -62,4 +62,10 @@ export const LOCAL_CLIENT = {
async strings() {
return {};
},
async restart(token) {
throw new NoServerError();
},
async status(token) {
throw new NoServerError();
}
} as ApiClient;
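The local (no-server) client gains restart and status stubs so it still satisfies the full ApiClient interface; both simply throw NoServerError because there is no server to administer. The same null-object idea sketched in Python, with hypothetical names standing in for the TypeScript above:

class NoServerError(Exception):
    # raised by every admin method of the offline client
    pass

class LocalClient:
    # stands in for the real API client when no server is configured
    def restart(self, token: str) -> None:
        raise NoServerError()

    def status(self, token: str) -> None:
        raise NoServerError()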

View File

@ -21,9 +21,12 @@ export function ModelControl() {
const setModel = useStore(state, (s) => s.setModel);
const { t } = useTranslation();
const token = '';
// get token from query string
const query = new URLSearchParams(window.location.search);
const token = query.get('token');
const [hash, _setHash] = useHash();
const restart = useMutation(['restart'], async () => client.restart(token));
const restart = useMutation(['restart'], async () => client.restart(mustExist(token)));
const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME,
});
@ -34,8 +37,6 @@ export function ModelControl() {
staleTime: STALE_TIME,
});
const [hash, _setHash] = useHash();
function addToken(type: string, name: string, weight = 1.0) {
const tab = getTab(hash);
const current = state.getState();