
feat(api): switch to device pool for background workers

Sean Sube 2023-02-04 10:06:22 -06:00
parent ecec0a2e56
commit 6426cff741
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 183 additions and 45 deletions

api/onnx_web/device_pool.py (new file, +113)

@@ -0,0 +1,113 @@
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from logging import getLogger
from multiprocessing import Value
from typing import Any, Callable, List, Optional, Union

logger = getLogger(__name__)


class JobContext:
    def __init__(
        self,
        key: str,
        devices: List[str],
        cancel: bool = False,
        device_index: int = -1,
        progress: int = 0,
    ):
        self.key = key
        self.devices = list(devices)
        # shared values, so cancellation and progress survive process boundaries
        self.cancel = Value('B', cancel)
        self.device_index = Value('i', device_index)
        self.progress = Value('I', progress)

    def is_cancelled(self) -> bool:
        return self.cancel.value

    def get_device(self) -> str:
        '''
        Get the device assigned to this job.
        '''
        with self.device_index.get_lock():
            device_index = self.device_index.value
            if device_index < 0:
                raise Exception('job has not been assigned to a device')
            else:
                return self.devices[device_index]

    def get_progress_callback(self) -> Callable[..., None]:
        def on_progress(step: int, timestep: int, latents: Any):
            if self.is_cancelled():
                raise Exception('job has been cancelled')
            else:
                self.set_progress(step)

        return on_progress

    def set_cancel(self, cancel: bool = True) -> None:
        with self.cancel.get_lock():
            self.cancel.value = cancel

    def set_progress(self, progress: int) -> None:
        with self.progress.get_lock():
            self.progress.value = progress


class Job:
    def __init__(
        self,
        key: str,
        future: Future,
        context: JobContext,
    ):
        self.context = context
        self.future = future
        self.key = key

    def set_cancel(self, cancel: bool = True):
        self.context.set_cancel(cancel)

    def set_progress(self, progress: int):
        self.context.set_progress(progress)


class DevicePoolExecutor:
    devices: List[str] = None
    jobs: List[Job] = None
    pool: Union[ProcessPoolExecutor, ThreadPoolExecutor] = None

    def __init__(
        self,
        devices: List[str],
        pool: Optional[Union[ProcessPoolExecutor, ThreadPoolExecutor]] = None,
    ):
        self.devices = devices
        self.jobs = []
        # default to one thread per device
        self.pool = pool or ThreadPoolExecutor(len(devices))

    def cancel(self, key: str) -> bool:
        '''
        Cancel a job. If the job has not been started, this will cancel
        the future and it will never execute. If the job has been started,
        it will be cancelled on its next progress callback.
        '''
        for job in self.jobs:
            if job.key == key:
                if job.future.cancel():
                    return True
                else:
                    # already running: set the shared flag through the job's
                    # context so the progress callback can raise and stop it
                    job.set_cancel()
                    return True

        return False

    def done(self, key: str) -> Optional[bool]:
        for job in self.jobs:
            if job.key == key:
                return job.future.done()

        logger.warning('checking status for unknown key: %s', key)
        return None

    def prune(self):
        # drop finished jobs so the list does not grow without bound
        self.jobs[:] = [job for job in self.jobs if not job.future.done()]

    def submit(self, key: str, fn: Callable[..., None], /, *args, **kwargs) -> None:
        # every job is assigned to the first device for now
        context = JobContext(key, self.devices, device_index=0)
        future = self.pool.submit(fn, context, *args, **kwargs)
        job = Job(key, future, context)
        self.jobs.append(job)
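For context, a minimal sketch of how the new pool is meant to be driven, using only the methods defined above; the `render` worker and its prompt argument are hypothetical stand-ins for the real pipeline functions:

    from time import sleep

    from onnx_web.device_pool import DevicePoolExecutor, JobContext


    def render(job: JobContext, prompt: str) -> None:
        # runs in a pool worker; the pool passes the JobContext as the first argument
        device = job.get_device()
        progress = job.get_progress_callback()
        for step in range(10):
            progress(step, 0, None)  # raises if the job has been cancelled
        print('rendered %s on %s' % (prompt, device))


    pool = DevicePoolExecutor(['cpu'])
    pool.submit('example-job', render, 'an astronaut eating a hamburger')

    while pool.done('example-job') is False:
        sleep(0.1)

    pool.prune()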

api/onnx_web/diffusion/run.py

@@ -9,6 +9,9 @@ from typing import Any
 from ..chain import (
     upscale_outpaint,
 )
+from ..device_pool import (
+    JobContext,
+)
 from ..params import (
     ImageParams,
     Border,
@@ -38,18 +41,21 @@ logger = getLogger(__name__)
 def run_txt2img_pipeline(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     params: ImageParams,
     size: Size,
     output: str,
     upscale: UpscaleParams
 ) -> None:
+    device = job.get_device()
     pipe = load_pipeline(OnnxStableDiffusionPipeline,
-                         params.model, params.provider, params.scheduler)
+                         params.model, params.provider, params.scheduler, device=device)
     latents = get_latents_from_seed(params.seed, size)
     rng = np.random.RandomState(params.seed)
+    progress = job.get_progress_callback()
     result = pipe(
         params.prompt,
         height=size.height,
@@ -59,13 +65,14 @@ def run_txt2img_pipeline(
         latents=latents,
         negative_prompt=params.negative_prompt,
         num_inference_steps=params.steps,
+        callback=progress,
     )
     image = result.images[0]

     image = run_upscale_correction(
-        ctx, StageParams(), params, image, upscale=upscale)
+        server, StageParams(), params, image, upscale=upscale)

-    dest = save_image(ctx, output, image)
-    save_params(ctx, output, params, size, upscale=upscale)
+    dest = save_image(server, output, image)
+    save_params(server, output, params, size, upscale=upscale)

     del image
     del result
@@ -75,18 +82,21 @@ def run_txt2img_pipeline(
 def run_img2img_pipeline(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     params: ImageParams,
     output: str,
     upscale: UpscaleParams,
     source_image: Image.Image,
     strength: float,
 ) -> None:
+    device = job.get_device()
     pipe = load_pipeline(OnnxStableDiffusionImg2ImgPipeline,
-                         params.model, params.provider, params.scheduler)
+                         params.model, params.provider, params.scheduler, device=device)
     rng = np.random.RandomState(params.seed)
+    progress = job.get_progress_callback()
     result = pipe(
         params.prompt,
         generator=rng,
@@ -95,14 +105,15 @@ def run_img2img_pipeline(
         negative_prompt=params.negative_prompt,
         num_inference_steps=params.steps,
         strength=strength,
+        callback=progress,
     )
     image = result.images[0]

     image = run_upscale_correction(
-        ctx, StageParams(), params, image, upscale=upscale)
+        server, StageParams(), params, image, upscale=upscale)

-    dest = save_image(ctx, output, image)
+    dest = save_image(server, output, image)
     size = Size(*source_image.size)
-    save_params(ctx, output, params, size, upscale=upscale)
+    save_params(server, output, params, size, upscale=upscale)

     del image
     del result
@@ -112,7 +123,8 @@ def run_img2img_pipeline(
 def run_inpaint_pipeline(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     params: ImageParams,
     size: Size,
     output: str,
@@ -125,9 +137,13 @@ def run_inpaint_pipeline(
     strength: float,
     fill_color: str,
 ) -> None:
+    device = job.get_device()
+    progress = job.get_progress_callback()
+    stage = StageParams()

+    # TODO: pass device, progress
     image = upscale_outpaint(
-        ctx,
+        server,
         stage,
         params,
         source_image,
@@ -146,10 +162,10 @@ def run_inpaint_pipeline(
             'output image size does not match source, skipping post-blend')

     image = run_upscale_correction(
-        ctx, stage, params, image, upscale=upscale)
+        server, stage, params, image, upscale=upscale)

-    dest = save_image(ctx, output, image)
-    save_params(ctx, output, params, size, upscale=upscale, border=border)
+    dest = save_image(server, output, image)
+    save_params(server, output, params, size, upscale=upscale, border=border)

     del image
     run_gc()
@@ -158,18 +174,24 @@ def run_inpaint_pipeline(
 def run_upscale_pipeline(
-    ctx: ServerContext,
+    job: JobContext,
+    server: ServerContext,
     params: ImageParams,
     size: Size,
     output: str,
     upscale: UpscaleParams,
     source_image: Image.Image,
 ) -> None:
-    image = run_upscale_correction(
-        ctx, StageParams(), params, source_image, upscale=upscale)
+    device = job.get_device()
+    progress = job.get_progress_callback()
+    stage = StageParams()

-    dest = save_image(ctx, output, image)
-    save_params(ctx, output, params, size, upscale=upscale)
+    # TODO: pass device, progress
+    image = run_upscale_correction(
+        server, stage, params, source_image, upscale=upscale)

+    dest = save_image(server, output, image)
+    save_params(server, output, params, size, upscale=upscale)

     del image
     run_gc()
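Taken together, the new convention gives every pipeline function a JobContext ahead of the ServerContext, and threading `callback=progress` into the pipeline call is what makes mid-run cancellation possible. A minimal sketch of that path, where `slow_worker` is a hypothetical stand-in for the diffusers step loop:

    from onnx_web.device_pool import DevicePoolExecutor, JobContext


    def slow_worker(job: JobContext) -> None:
        progress = job.get_progress_callback()
        for step in range(1000):
            # the real callback is invoked by the pipeline once per step;
            # it raises as soon as the shared cancel flag is set
            progress(step, 0, None)


    pool = DevicePoolExecutor(['cpu'])
    pool.submit('slow', slow_worker)
    pool.cancel('slow')  # flips the flag; the worker's next progress call raises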

api/onnx_web/serve.py

@@ -41,6 +41,9 @@ from .chain import (
     upscale_stable_diffusion,
     ChainPipeline,
 )
+from .device_pool import (
+    DevicePoolExecutor,
+)
 from .diffusion.run import (
     run_img2img_pipeline,
     run_inpaint_pipeline,
@@ -320,22 +323,14 @@ load_params(context)
 load_platforms()

 app = Flask(__name__)
-app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
-app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True

 CORS(app, origins=context.cors_origin)
-executor = Executor(app)
+executor = DevicePoolExecutor(available_platforms)

 if is_debug():
     gc.set_debug(gc.DEBUG_STATS)

-# TODO: these two use context
-def get_model_path(model: str):
-    return base_join(context.model_path, model)

 def ready_reply(ready: bool):
     return jsonify({
         'ready': ready,
@@ -350,6 +345,10 @@ def error_reply(err: str):
     return response

+def get_model_path(model: str):
+    return base_join(context.model_path, model)
+
 def serve_bundle_file(filename='index.html'):
     return send_from_directory(path.join('..', context.bundle_path), filename)
@@ -439,8 +438,8 @@ def img2img():
     logger.info("img2img job queued for: %s", output)
     source_image.thumbnail((size.width, size.height))

-    executor.submit_stored(output, run_img2img_pipeline,
-                           context, params, output, upscale, source_image, strength)
+    executor.submit(output, run_img2img_pipeline,
+                    context, params, output, upscale, source_image, strength)

     return jsonify(json_params(output, params, size, upscale=upscale))
@@ -457,7 +456,7 @@ def txt2img():
                   size)
     logger.info("txt2img job queued for: %s", output)

-    executor.submit_stored(
+    executor.submit(
         output, run_txt2img_pipeline, context, params, size, output, upscale)

     return jsonify(json_params(output, params, size, upscale=upscale))
@@ -512,7 +511,7 @@ def inpaint():
     source_image.thumbnail((size.width, size.height))
     mask_image.thumbnail((size.width, size.height))

-    executor.submit_stored(
+    executor.submit(
         output,
         run_inpaint_pipeline,
         context,
@@ -550,8 +549,8 @@ def upscale():
     logger.info("upscale job queued for: %s", output)
     source_image.thumbnail((size.width, size.height))

-    executor.submit_stored(output, run_upscale_pipeline,
-                           context, params, size, output, upscale, source_image)
+    executor.submit(output, run_upscale_pipeline,
+                    context, params, size, output, upscale, source_image)

     return jsonify(json_params(output, params, size, upscale=upscale))
@@ -600,8 +599,8 @@ def chain():
     # build and run chain pipeline
     empty_source = Image.new('RGB', (size.width, size.height))
-    executor.submit_stored(output, pipeline, context,
-                           params, empty_source, output=output, size=size)
+    executor.submit(output, pipeline, context,
+                    params, empty_source, output=output, size=size)

     return jsonify(json_params(output, params, size))
@@ -610,19 +609,25 @@ def chain():
 def ready():
     output_file = request.args.get('output', None)

-    done = executor.futures.done(output_file)
+    done = executor.done(output_file)
     if done is None:
         file = base_join(context.output_path, output_file)
         if path.exists(file):
             return ready_reply(True)

-    elif done == True:
-        executor.futures.pop(output_file)

     return ready_reply(done)


+@app.route('/api/cancel', methods=['PUT'])
+def cancel():
+    output_file = request.args.get('output', None)
+
+    cancel = executor.cancel(output_file)
+    return ready_reply(cancel)
+
 @app.route('/output/<path:filename>')
 def output(filename: str):
     return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)
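For completeness, a client-side sketch of the new cancel route; this assumes the API is listening on the Flask default local address and port, and the output name is a placeholder:

    import requests

    resp = requests.put(
        'http://127.0.0.1:5000/api/cancel',
        params={'output': 'txt2img_0000.png'},  # placeholder output name
    )
    print(resp.json())  # ready_reply shape: {'ready': ...}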

outputs/.gitignore (vendored, +1/-3)

@@ -1,3 +1 @@
-*.png
-*.jpg
-*.jpeg
+*