feat(api): add progress to ready endpoint
This commit is contained in:
parent
1491a9e1e0
commit
294c831d02
|
@ -1,7 +1,7 @@
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
|
||||||
from logging import getLogger
|
from logging import getLogger
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
from typing import Any, Callable, List, Union, Optional
|
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -35,6 +35,9 @@ class JobContext:
|
||||||
else:
|
else:
|
||||||
return self.devices[device_index]
|
return self.devices[device_index]
|
||||||
|
|
||||||
|
def get_progress(self) -> int:
|
||||||
|
return self.progress.value
|
||||||
|
|
||||||
def get_progress_callback(self) -> Callable[..., None]:
|
def get_progress_callback(self) -> Callable[..., None]:
|
||||||
def on_progress(step: int, timestep: int, latents: Any):
|
def on_progress(step: int, timestep: int, latents: Any):
|
||||||
if self.is_cancelled():
|
if self.is_cancelled():
|
||||||
|
@ -64,6 +67,9 @@ class Job:
|
||||||
self.future = future
|
self.future = future
|
||||||
self.key = key
|
self.key = key
|
||||||
|
|
||||||
|
def get_progress(self) -> int:
|
||||||
|
self.context.get_progress()
|
||||||
|
|
||||||
def set_cancel(self, cancel: bool = True):
|
def set_cancel(self, cancel: bool = True):
|
||||||
self.context.set_cancel(cancel)
|
self.context.set_cancel(cancel)
|
||||||
|
|
||||||
|
@ -94,13 +100,15 @@ class DevicePoolExecutor:
|
||||||
else:
|
else:
|
||||||
job.set_cancel()
|
job.set_cancel()
|
||||||
|
|
||||||
def done(self, key: str) -> bool:
|
def done(self, key: str) -> Tuple[bool, int]:
|
||||||
for job in self.jobs:
|
for job in self.jobs:
|
||||||
if job.key == key:
|
if job.key == key:
|
||||||
return job.future.done()
|
done = job.future.done()
|
||||||
|
progress = job.get_progress()
|
||||||
|
return (done, progress)
|
||||||
|
|
||||||
logger.warn('checking status for unknown key: %s', key)
|
logger.warn('checking status for unknown key: %s', key)
|
||||||
return None
|
return (None, 0)
|
||||||
|
|
||||||
def prune(self):
|
def prune(self):
|
||||||
self.jobs[:] = [job for job in self.jobs if job.future.done()]
|
self.jobs[:] = [job for job in self.jobs if job.future.done()]
|
||||||
|
|
|
@ -331,8 +331,9 @@ if is_debug():
|
||||||
gc.set_debug(gc.DEBUG_STATS)
|
gc.set_debug(gc.DEBUG_STATS)
|
||||||
|
|
||||||
|
|
||||||
def ready_reply(ready: bool):
|
def ready_reply(ready: bool, progress: int = 0):
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
'progress': progress,
|
||||||
'ready': ready,
|
'ready': ready,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -609,14 +610,14 @@ def chain():
|
||||||
def ready():
|
def ready():
|
||||||
output_file = request.args.get('output', None)
|
output_file = request.args.get('output', None)
|
||||||
|
|
||||||
done = executor.done(output_file)
|
done, progress = executor.done(output_file)
|
||||||
|
|
||||||
if done is None:
|
if done is None:
|
||||||
file = base_join(context.output_path, output_file)
|
file = base_join(context.output_path, output_file)
|
||||||
if path.exists(file):
|
if path.exists(file):
|
||||||
return ready_reply(True)
|
return ready_reply(True)
|
||||||
|
|
||||||
return ready_reply(done)
|
return ready_reply(done, progress=progress)
|
||||||
|
|
||||||
|
|
||||||
@app.route('/api/cancel', methods=['PUT'])
|
@app.route('/api/cancel', methods=['PUT'])
|
||||||
|
|
Loading…
Reference in New Issue