1
0
Fork 0

feat(api): add progress to ready endpoint

This commit is contained in:
Sean Sube 2023-02-04 10:16:30 -06:00
parent 1491a9e1e0
commit 294c831d02
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 16 additions and 7 deletions

View File

@ -1,7 +1,7 @@
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
from logging import getLogger
from multiprocessing import Value
from typing import Any, Callable, List, Union, Optional
from typing import Any, Callable, List, Optional, Tuple, Union
logger = getLogger(__name__)
@ -35,6 +35,9 @@ class JobContext:
else:
return self.devices[device_index]
def get_progress(self) -> int:
return self.progress.value
def get_progress_callback(self) -> Callable[..., None]:
def on_progress(step: int, timestep: int, latents: Any):
if self.is_cancelled():
@ -64,6 +67,9 @@ class Job:
self.future = future
self.key = key
def get_progress(self) -> int:
self.context.get_progress()
def set_cancel(self, cancel: bool = True):
self.context.set_cancel(cancel)
@ -94,13 +100,15 @@ class DevicePoolExecutor:
else:
job.set_cancel()
def done(self, key: str) -> bool:
def done(self, key: str) -> Tuple[bool, int]:
for job in self.jobs:
if job.key == key:
return job.future.done()
done = job.future.done()
progress = job.get_progress()
return (done, progress)
logger.warn('checking status for unknown key: %s', key)
return None
return (None, 0)
def prune(self):
self.jobs[:] = [job for job in self.jobs if job.future.done()]

View File

@ -331,8 +331,9 @@ if is_debug():
gc.set_debug(gc.DEBUG_STATS)
def ready_reply(ready: bool):
def ready_reply(ready: bool, progress: int = 0):
return jsonify({
'progress': progress,
'ready': ready,
})
@ -609,14 +610,14 @@ def chain():
def ready():
output_file = request.args.get('output', None)
done = executor.done(output_file)
done, progress = executor.done(output_file)
if done is None:
file = base_join(context.output_path, output_file)
if path.exists(file):
return ready_reply(True)
return ready_reply(done)
return ready_reply(done, progress=progress)
@app.route('/api/cancel', methods=['PUT'])