feat(api): switch to device pool for background workers
This commit is contained in:
parent
ecec0a2e56
commit
6426cff741
|
@ -0,0 +1,113 @@
|
|||
from concurrent.futures import Future, ThreadPoolExecutor, ProcessPoolExecutor
|
||||
from logging import getLogger
|
||||
from multiprocessing import Value
|
||||
from typing import Any, Callable, List, Union, Optional
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
class JobContext:
    '''
    Per-job state handed to a worker function as its first argument.

    The cancel flag, device index, and progress counter are stored in
    multiprocessing.Value objects and written while holding each
    value's own lock.
    '''

    def __init__(
        self,
        key: str,
        devices: List[str],
        cancel: bool = False,
        device_index: int = -1,
        progress: int = 0,
    ):
        '''
        key: identifier of the job this context belongs to.
        devices: device names to choose from; copied, so later changes
            to the caller's list are not seen here.
        cancel: initial state of the cancel flag.
        device_index: index into devices, negative while unassigned.
        progress: initial progress counter.
        '''
        self.key = key
        self.devices = list(devices)
        # typecodes: 'B' unsigned char, 'i' signed int, 'I' unsigned int
        self.cancel = Value('B', cancel)
        self.device_index = Value('i', device_index)
        self.progress = Value('I', progress)

    def is_cancelled(self) -> bool:
        '''
        Report whether cancellation has been requested for this job.
        '''
        # the shared value is an unsigned char (0 or 1), so coerce it to
        # honor the declared bool return type
        return bool(self.cancel.value)

    def get_device(self) -> str:
        '''
        Get the device assigned to this job.

        Raises an Exception when no device has been assigned yet, which
        is signalled by a negative device index.
        '''
        with self.device_index.get_lock():
            device_index = self.device_index.value
            if device_index < 0:
                raise Exception('job has not been assigned to a device')

            return self.devices[device_index]

    def get_progress_callback(self) -> Callable[..., None]:
        '''
        Build a callback taking (step, timestep, latents).

        The callback records step as the job progress, or raises an
        Exception once the job has been cancelled, leaving the stored
        progress untouched. timestep and latents are accepted but not
        used.
        '''
        def on_progress(step: int, timestep: int, latents: Any) -> None:
            if self.is_cancelled():
                raise Exception('job has been cancelled')

            self.set_progress(step)

        return on_progress

    def set_cancel(self, cancel: bool = True) -> None:
        '''
        Set (or, with cancel=False, clear) the cancel flag.
        '''
        with self.cancel.get_lock():
            self.cancel.value = cancel

    def set_progress(self, progress: int) -> None:
        '''
        Store the latest progress counter for this job.
        '''
        with self.progress.get_lock():
            self.progress.value = progress
|
||||
|
||||
|
||||
class Job:
    '''
    Record of one submitted job: its key, the future returned by the
    pool, and the context object that was passed to the worker.
    '''

    def __init__(self, key: str, future: Future, context: 'JobContext'):
        self.key = key
        self.future = future
        self.context = context

    def set_cancel(self, cancel: bool = True):
        '''
        Forward a cancel request to the job's context.
        '''
        self.context.set_cancel(cancel)

    def set_progress(self, progress: int):
        '''
        Forward a progress update to the job's context.
        '''
        self.context.set_progress(progress)
|
||||
|
||||
|
||||
class DevicePoolExecutor:
    '''
    Run background jobs on an executor pool and track them by key so
    they can be polled, cancelled, and pruned later.
    '''

    devices: List[str] = None
    jobs: List['Job'] = None
    pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None

    def __init__(
        self,
        devices: List[str],
        pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
    ):
        '''
        devices: device names offered to each job's context.
        pool: executor to run jobs on; when omitted, a thread pool with
            one worker per device is created.
        '''
        self.devices = devices
        self.jobs = []
        self.pool = pool or ThreadPoolExecutor(len(devices))

    def cancel(self, key: str) -> bool:
        '''
        Cancel a job. If the job has not been started, this will cancel
        the future and never execute it. If the job has been started, it
        should be cancelled on the next progress callback.

        Returns True when the job was cancelled or flagged for
        cancellation, and False when the key is unknown or the job has
        already finished.
        '''
        for job in self.jobs:
            if job.key != key:
                continue

            if job.future.cancel():
                return True

            if job.future.done():
                # already finished, there is nothing left to interrupt
                return False

            # running: the worker sees the flag via its progress callback
            job.set_cancel()
            return True

        return False

    def done(self, key: str) -> Optional[bool]:
        '''
        Check whether the job with the given key has finished.

        Returns None, after logging a warning, when no job with that key
        is being tracked.
        '''
        for job in self.jobs:
            if job.key == key:
                return job.future.done()

        logger.warning('checking status for unknown key: %s', key)
        return None

    def prune(self):
        '''
        Stop tracking jobs whose futures are done, keeping pending and
        running jobs. The list is updated in place.
        '''
        self.jobs[:] = [job for job in self.jobs if not job.future.done()]

    def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
        '''
        Queue fn on the pool and start tracking it under key.

        fn is called with a new JobContext as its first argument,
        followed by args and kwargs.
        '''
        # NOTE(review): every job is created with device index 0, so only
        # the first device is ever handed out - confirm this is intended
        context = JobContext(key, self.devices, device_index=0)
        future = self.pool.submit(fn, context, *args, **kwargs)
        self.jobs.append(Job(key, future, context))
|
|
@ -9,6 +9,9 @@ from typing import Any
|
|||
from ..chain import (
|
||||
upscale_outpaint,
|
||||
)
|
||||
from ..device_pool import (
|
||||
JobContext,
|
||||
)
|
||||
from ..params import (
|
||||
ImageParams,
|
||||
Border,
|
||||
|
@ -38,18 +41,21 @@ logger = getLogger(__name__)
|
|||
|
||||
|
||||
def run_txt2img_pipeline(
|
||||
ctx: ServerContext,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams
|
||||
) -> None:
|
||||
device = job.get_device()
|
||||
pipe = load_pipeline(OnnxStableDiffusionPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
params.model, params.provider, params.scheduler, device=device)
|
||||
|
||||
latents = get_latents_from_seed(params.seed, size)
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
height=size.height,
|
||||
|
@ -59,13 +65,14 @@ def run_txt2img_pipeline(
|
|||
latents=latents,
|
||||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
callback=progress,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, image, upscale=upscale)
|
||||
server, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = save_image(ctx, output, image)
|
||||
save_params(ctx, output, params, size, upscale=upscale)
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
del result
|
||||
|
@ -75,18 +82,21 @@ def run_txt2img_pipeline(
|
|||
|
||||
|
||||
def run_img2img_pipeline(
|
||||
ctx: ServerContext,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image.Image,
|
||||
strength: float,
|
||||
) -> None:
|
||||
device = job.get_device()
|
||||
pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
|
||||
params.model, params.provider, params.scheduler)
|
||||
params.model, params.provider, params.scheduler, device=device)
|
||||
|
||||
rng = np.random.RandomState(params.seed)
|
||||
|
||||
progress = job.get_progress_callback()
|
||||
result = pipe(
|
||||
params.prompt,
|
||||
generator=rng,
|
||||
|
@ -95,14 +105,15 @@ def run_img2img_pipeline(
|
|||
negative_prompt=params.negative_prompt,
|
||||
num_inference_steps=params.steps,
|
||||
strength=strength,
|
||||
callback=progress,
|
||||
)
|
||||
image = result.images[0]
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, image, upscale=upscale)
|
||||
server, StageParams(), params, image, upscale=upscale)
|
||||
|
||||
dest = save_image(ctx, output, image)
|
||||
dest = save_image(server, output, image)
|
||||
size = Size(*source_image.size)
|
||||
save_params(ctx, output, params, size, upscale=upscale)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
del result
|
||||
|
@ -112,7 +123,8 @@ def run_img2img_pipeline(
|
|||
|
||||
|
||||
def run_inpaint_pipeline(
|
||||
ctx: ServerContext,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
|
@ -125,9 +137,13 @@ def run_inpaint_pipeline(
|
|||
strength: float,
|
||||
fill_color: str,
|
||||
) -> None:
|
||||
device = job.get_device()
|
||||
progress = job.get_progress_callback()
|
||||
stage = StageParams()
|
||||
|
||||
# TODO: pass device, progress
|
||||
image = upscale_outpaint(
|
||||
ctx,
|
||||
server,
|
||||
stage,
|
||||
params,
|
||||
source_image,
|
||||
|
@ -146,10 +162,10 @@ def run_inpaint_pipeline(
|
|||
'output image size does not match source, skipping post-blend')
|
||||
|
||||
image = run_upscale_correction(
|
||||
ctx, stage, params, image, upscale=upscale)
|
||||
server, stage, params, image, upscale=upscale)
|
||||
|
||||
dest = save_image(ctx, output, image)
|
||||
save_params(ctx, output, params, size, upscale=upscale, border=border)
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale, border=border)
|
||||
|
||||
del image
|
||||
run_gc()
|
||||
|
@ -158,18 +174,24 @@ def run_inpaint_pipeline(
|
|||
|
||||
|
||||
def run_upscale_pipeline(
|
||||
ctx: ServerContext,
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
params: ImageParams,
|
||||
size: Size,
|
||||
output: str,
|
||||
upscale: UpscaleParams,
|
||||
source_image: Image.Image,
|
||||
) -> None:
|
||||
image = run_upscale_correction(
|
||||
ctx, StageParams(), params, source_image, upscale=upscale)
|
||||
device = job.get_device()
|
||||
progress = job.get_progress_callback()
|
||||
stage = StageParams()
|
||||
|
||||
dest = save_image(ctx, output, image)
|
||||
save_params(ctx, output, params, size, upscale=upscale)
|
||||
# TODO: pass device, progress
|
||||
image = run_upscale_correction(
|
||||
server, stage, params, source_image, upscale=upscale)
|
||||
|
||||
dest = save_image(server, output, image)
|
||||
save_params(server, output, params, size, upscale=upscale)
|
||||
|
||||
del image
|
||||
run_gc()
|
||||
|
|
|
@ -41,6 +41,9 @@ from .chain import (
|
|||
upscale_stable_diffusion,
|
||||
ChainPipeline,
|
||||
)
|
||||
from .device_pool import (
|
||||
DevicePoolExecutor,
|
||||
)
|
||||
from .diffusion.run import (
|
||||
run_img2img_pipeline,
|
||||
run_inpaint_pipeline,
|
||||
|
@ -320,22 +323,14 @@ load_params(context)
|
|||
load_platforms()
|
||||
|
||||
app = Flask(__name__)
|
||||
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
|
||||
app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True
|
||||
|
||||
CORS(app, origins=context.cors_origin)
|
||||
executor = Executor(app)
|
||||
|
||||
executor = DevicePoolExecutor(available_platforms)
|
||||
|
||||
if is_debug():
|
||||
gc.set_debug(gc.DEBUG_STATS)
|
||||
|
||||
|
||||
# TODO: these two use context
|
||||
|
||||
def get_model_path(model: str):
|
||||
return base_join(context.model_path, model)
|
||||
|
||||
|
||||
def ready_reply(ready: bool):
|
||||
return jsonify({
|
||||
'ready': ready,
|
||||
|
@ -350,6 +345,10 @@ def error_reply(err: str):
|
|||
return response
|
||||
|
||||
|
||||
def get_model_path(model: str):
    '''
    Join the model name onto the configured model path using base_join.
    '''
    model_root = context.model_path
    return base_join(model_root, model)
|
||||
|
||||
|
||||
def serve_bundle_file(filename='index.html'):
    '''
    Send a file from the configured bundle path, index.html by default.
    '''
    bundle_dir = path.join('..', context.bundle_path)
    return send_from_directory(bundle_dir, filename)
|
||||
|
||||
|
@ -439,8 +438,8 @@ def img2img():
|
|||
logger.info("img2img job queued for: %s", output)
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
executor.submit_stored(output, run_img2img_pipeline,
|
||||
context, params, output, upscale, source_image, strength)
|
||||
executor.submit(output, run_img2img_pipeline,
|
||||
context, params, output, upscale, source_image, strength)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
@ -457,7 +456,7 @@ def txt2img():
|
|||
size)
|
||||
logger.info("txt2img job queued for: %s", output)
|
||||
|
||||
executor.submit_stored(
|
||||
executor.submit(
|
||||
output, run_txt2img_pipeline, context, params, size, output, upscale)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
@ -512,7 +511,7 @@ def inpaint():
|
|||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
mask_image.thumbnail((size.width, size.height))
|
||||
executor.submit_stored(
|
||||
executor.submit(
|
||||
output,
|
||||
run_inpaint_pipeline,
|
||||
context,
|
||||
|
@ -550,8 +549,8 @@ def upscale():
|
|||
logger.info("upscale job queued for: %s", output)
|
||||
|
||||
source_image.thumbnail((size.width, size.height))
|
||||
executor.submit_stored(output, run_upscale_pipeline,
|
||||
context, params, size, output, upscale, source_image)
|
||||
executor.submit(output, run_upscale_pipeline,
|
||||
context, params, size, output, upscale, source_image)
|
||||
|
||||
return jsonify(json_params(output, params, size, upscale=upscale))
|
||||
|
||||
|
@ -600,8 +599,8 @@ def chain():
|
|||
|
||||
# build and run chain pipeline
|
||||
empty_source = Image.new('RGB', (size.width, size.height))
|
||||
executor.submit_stored(output, pipeline, context,
|
||||
params, empty_source, output=output, size=size)
|
||||
executor.submit(output, pipeline, context,
|
||||
params, empty_source, output=output, size=size)
|
||||
|
||||
return jsonify(json_params(output, params, size))
|
||||
|
||||
|
@ -610,19 +609,25 @@ def chain():
|
|||
def ready():
|
||||
output_file = request.args.get('output', None)
|
||||
|
||||
done = executor.futures.done(output_file)
|
||||
done = executor.done(output_file)
|
||||
|
||||
if done is None:
|
||||
file = base_join(context.output_path, output_file)
|
||||
if path.exists(file):
|
||||
return ready_reply(True)
|
||||
|
||||
elif done == True:
|
||||
executor.futures.pop(output_file)
|
||||
|
||||
return ready_reply(done)
|
||||
|
||||
|
||||
@app.route('/api/cancel', methods=['PUT'])
def cancel():
    '''
    Ask the executor to cancel the job named by the output query
    parameter and reply with the result it returns.
    '''
    job_key = request.args.get('output', None)
    return ready_reply(executor.cancel(job_key))
|
||||
|
||||
|
||||
@app.route('/output/<path:filename>')
def output(filename: str):
    '''
    Send a file from the configured output path, not as an attachment.
    '''
    output_dir = path.join('..', context.output_path)
    return send_from_directory(output_dir, filename, as_attachment=False)
|
||||
|
|
|
@ -1,3 +1 @@
|
|||
*.png
|
||||
*.jpg
|
||||
*.jpeg
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue