feat(api): add progress to ready endpoint
This commit is contained in:
parent
1491a9e1e0
commit
294c831d02
|
@ -1,7 +1,7 @@
|
|||
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from logging import getLogger
|
||||
from multiprocessing import Value
|
||||
from typing import Any, Callable, List, Union, Optional
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -35,6 +35,9 @@ class JobContext:
|
|||
else:
|
||||
return self.devices[device_index]
|
||||
|
||||
def get_progress(self) -> int:
|
||||
return self.progress.value
|
||||
|
||||
def get_progress_callback(self) -> Callable[..., None]:
|
||||
def on_progress(step: int, timestep: int, latents: Any):
|
||||
if self.is_cancelled():
|
||||
|
@ -64,6 +67,9 @@ class Job:
|
|||
self.future = future
|
||||
self.key = key
|
||||
|
||||
def get_progress(self) -> int:
|
||||
self.context.get_progress()
|
||||
|
||||
def set_cancel(self, cancel: bool = True):
|
||||
self.context.set_cancel(cancel)
|
||||
|
||||
|
@ -94,13 +100,15 @@ class DevicePoolExecutor:
|
|||
else:
|
||||
job.set_cancel()
|
||||
|
||||
def done(self, key: str) -> bool:
|
||||
def done(self, key: str) -> Tuple[bool, int]:
|
||||
for job in self.jobs:
|
||||
if job.key == key:
|
||||
return job.future.done()
|
||||
done = job.future.done()
|
||||
progress = job.get_progress()
|
||||
return (done, progress)
|
||||
|
||||
logger.warn('checking status for unknown key: %s', key)
|
||||
return None
|
||||
return (None, 0)
|
||||
|
||||
def prune(self):
|
||||
self.jobs[:] = [job for job in self.jobs if job.future.done()]
|
||||
|
|
|
@ -331,8 +331,9 @@ if is_debug():
|
|||
gc.set_debug(gc.DEBUG_STATS)
|
||||
|
||||
|
||||
def ready_reply(ready: bool):
|
||||
def ready_reply(ready: bool, progress: int = 0):
|
||||
return jsonify({
|
||||
'progress': progress,
|
||||
'ready': ready,
|
||||
})
|
||||
|
||||
|
@ -609,14 +610,14 @@ def chain():
|
|||
def ready():
|
||||
output_file = request.args.get('output', None)
|
||||
|
||||
done = executor.done(output_file)
|
||||
done, progress = executor.done(output_file)
|
||||
|
||||
if done is None:
|
||||
file = base_join(context.output_path, output_file)
|
||||
if path.exists(file):
|
||||
return ready_reply(True)
|
||||
|
||||
return ready_reply(done)
|
||||
return ready_reply(done, progress=progress)
|
||||
|
||||
|
||||
@app.route('/api/cancel', methods=['PUT'])
|
||||
|
|
Loading…
Reference in New Issue