1
0
Fork 0

send token with admin requests, return worker status after restarting

This commit is contained in:
Sean Sube 2023-04-20 17:36:29 -05:00
parent df0e7dc57e
commit 30c96be24f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
7 changed files with 36 additions and 17 deletions

View File

@@ -73,14 +73,14 @@ def main():
def run(): def run():
_server, app, pool = main() server, app, pool = main()
pool.start() pool.start()
def quit(p: DevicePoolExecutor): def quit(p: DevicePoolExecutor):
logger.info("shutting down workers") logger.info("shutting down workers")
p.join() p.join()
# TODO: print admin token logger.info("starting API server with admin token: %s", server.admin_token)
atexit.register(partial(quit, pool)) atexit.register(partial(quit, pool))
return app return app

View File

@@ -18,10 +18,11 @@ def restart_workers(server: ServerContext, pool: DevicePoolExecutor):
return make_response(jsonify({})), 401 return make_response(jsonify({})), 401
logger.info("restarting worker pool") logger.info("restarting worker pool")
pool.join() pool.recycle(recycle_all=True)
pool.start()
logger.info("restarted worker pool") logger.info("restarted worker pool")
return jsonify(pool.status())
def worker_status(server: ServerContext, pool: DevicePoolExecutor): def worker_status(server: ServerContext, pool: DevicePoolExecutor):
return jsonify(pool.status()) return jsonify(pool.status())

View File

@@ -21,6 +21,7 @@ class WorkerContext:
progress: "Queue[ProgressCommand]" progress: "Queue[ProgressCommand]"
last_progress: Optional[ProgressCommand] last_progress: Optional[ProgressCommand]
idle: "Value[bool]" idle: "Value[bool]"
timeout: float
def __init__( def __init__(
self, self,
@@ -42,6 +43,7 @@ class WorkerContext:
self.active_pid = active_pid self.active_pid = active_pid
self.last_progress = None self.last_progress = None
self.idle = idle self.idle = idle
self.timeout = 1.0
def start(self, job: str) -> None: def start(self, job: str) -> None:
self.job = job self.job = job

View File

@@ -27,6 +27,10 @@ class DevicePoolExecutor:
recycle_interval: float recycle_interval: float
leaking: List[Tuple[str, Process, WorkerContext]] leaking: List[Tuple[str, Process, WorkerContext]]
worker_cancel: Dict[str, "Value[bool]"]
worker_idle: Dict[str, "Value[bool]"]
context: Dict[str, WorkerContext] # Device -> Context context: Dict[str, WorkerContext] # Device -> Context
current: Dict[str, "Value[int]"] # Device -> pid current: Dict[str, "Value[int]"] # Device -> pid
pending: Dict[str, "Queue[JobCommand]"] pending: Dict[str, "Queue[JobCommand]"]
@@ -51,7 +55,7 @@ class DevicePoolExecutor:
server: ServerContext, server: ServerContext,
devices: List[DeviceParams], devices: List[DeviceParams],
max_pending_per_worker: int = 100, max_pending_per_worker: int = 100,
join_timeout: float = 1.0, join_timeout: float = 5.0,
recycle_interval: float = 10, recycle_interval: float = 10,
progress_interval: float = 1.0, progress_interval: float = 1.0,
): ):
@@ -76,6 +80,8 @@ class DevicePoolExecutor:
self.pending_jobs = [] self.pending_jobs = []
self.running_jobs = {} self.running_jobs = {}
self.total_jobs = {} self.total_jobs = {}
self.worker_cancel = {}
self.worker_idle = {}
self.logs = Queue(self.max_pending_per_worker) self.logs = Queue(self.max_pending_per_worker)
self.rlock = Lock() self.rlock = Lock()
@@ -105,16 +111,19 @@ class DevicePoolExecutor:
current = Value("L", 0) current = Value("L", 0)
self.current[name] = current self.current[name] = current
self.worker_cancel[name] = Value("B", False)
self.worker_idle[name] = Value("B", False)
# create a new context and worker # create a new context and worker
context = WorkerContext( context = WorkerContext(
name, name,
device, device,
cancel=Value("B", False), cancel=self.worker_cancel[name],
progress=self.progress[name], progress=self.progress[name],
logs=self.logs, logs=self.logs,
pending=self.pending[name], pending=self.pending[name],
active_pid=current, active_pid=current,
idle=Value("B", False), idle=self.worker_idle[name],
) )
self.context[name] = context self.context[name] = context
@@ -283,12 +292,12 @@ class DevicePoolExecutor:
def join_leaking(self): def join_leaking(self):
if len(self.leaking) > 0: if len(self.leaking) > 0:
for device, worker, context in self.leaking: for device, worker, context in self.leaking:
logger.warning( logger.debug(
"shutting down leaking worker %s for device %s", worker.pid, device "shutting down leaking worker %s for device %s", worker.pid, device
) )
worker.join(self.join_timeout) worker.join(self.join_timeout)
if worker.is_alive(): if worker.is_alive():
logger.error( logger.warning(
"leaking worker %s for device %s could not be shut down", "leaking worker %s for device %s could not be shut down",
worker.pid, worker.pid,
device, device,
@@ -310,7 +319,7 @@ class DevicePoolExecutor:
self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()] self.leaking[:] = [dw for dw in self.leaking if dw[1].is_alive()]
def recycle(self): def recycle(self, recycle_all=False):
logger.debug("recycling worker pool") logger.debug("recycling worker pool")
with self.rlock: with self.rlock:
@@ -323,14 +332,14 @@ class DevicePoolExecutor:
if not worker.is_alive(): if not worker.is_alive():
logger.warning("worker for device %s has died", device) logger.warning("worker for device %s has died", device)
needs_restart.append(device) needs_restart.append(device)
elif jobs > self.max_jobs_per_worker: elif recycle_all or jobs > self.max_jobs_per_worker:
logger.info( logger.info(
"shutting down worker for device %s after %s jobs", device, jobs "shutting down worker for device %s after %s jobs", device, jobs
) )
worker.join(self.join_timeout) worker.join(self.join_timeout)
if worker.is_alive(): if worker.is_alive():
logger.warning( logger.warning(
"worker %s for device %s could not be recycled in time", "worker %s for device %s could not be shut down in time",
worker.pid, worker.pid,
device, device,
) )

View File

@@ -48,7 +48,7 @@ def worker_main(worker: WorkerContext, server: ServerContext):
exit(EXIT_REPLACED) exit(EXIT_REPLACED)
# wait briefly for the next job # wait briefly for the next job
job = worker.pending.get(timeout=1.0) job = worker.pending.get(timeout=worker.timeout)
logger.info("worker %s got job: %s", worker.device.device, job.name) logger.info("worker %s got job: %s", worker.device.device, job.name)
# clear flags and save the job name # clear flags and save the job name

View File

@@ -62,4 +62,10 @@ export const LOCAL_CLIENT = {
async strings() { async strings() {
return {}; return {};
}, },
async restart(token) {
throw new NoServerError();
},
async status(token) {
throw new NoServerError();
}
} as ApiClient; } as ApiClient;

View File

@@ -21,9 +21,12 @@ export function ModelControl() {
const setModel = useStore(state, (s) => s.setModel); const setModel = useStore(state, (s) => s.setModel);
const { t } = useTranslation(); const { t } = useTranslation();
const token = ''; // get token from query string
const query = new URLSearchParams(window.location.search);
const token = query.get('token');
const [hash, _setHash] = useHash();
const restart = useMutation(['restart'], async () => client.restart(token)); const restart = useMutation(['restart'], async () => client.restart(mustExist(token)));
const models = useQuery(['models'], async () => client.models(), { const models = useQuery(['models'], async () => client.models(), {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
@@ -34,8 +37,6 @@ export function ModelControl() {
staleTime: STALE_TIME, staleTime: STALE_TIME,
}); });
const [hash, _setHash] = useHash();
function addToken(type: string, name: string, weight = 1.0) { function addToken(type: string, name: string, weight = 1.0) {
const tab = getTab(hash); const tab = getTab(hash);
const current = state.getState(); const current = state.getState();