From 3637f642c694d1a7f96fc51e895cf7a49967fa7c Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 4 Feb 2023 14:52:23 -0600 Subject: [PATCH] fix(api): pass job context and device to upscaling --- api/onnx_web/diffusion/run.py | 10 ++++------ api/onnx_web/serve.py | 22 ++++++++++++++++------ api/onnx_web/upscale.py | 8 ++++++-- common/pipelines/outpaint.json | 1 + 4 files changed, 27 insertions(+), 14 deletions(-) diff --git a/api/onnx_web/diffusion/run.py b/api/onnx_web/diffusion/run.py index 4bf3cccc..1b44c700 100644 --- a/api/onnx_web/diffusion/run.py +++ b/api/onnx_web/diffusion/run.py @@ -69,7 +69,7 @@ def run_txt2img_pipeline( ) image = result.images[0] image = run_upscale_correction( - server, StageParams(), params, image, upscale=upscale) + job, server, StageParams(), params, image, upscale=upscale) dest = save_image(server, output, image) save_params(server, output, params, size, upscale=upscale) @@ -109,7 +109,7 @@ def run_img2img_pipeline( ) image = result.images[0] image = run_upscale_correction( - server, StageParams(), params, image, upscale=upscale) + job, server, StageParams(), params, image, upscale=upscale) dest = save_image(server, output, image) size = Size(*source_image.size) @@ -141,7 +141,6 @@ def run_inpaint_pipeline( # progress = job.get_progress_callback() stage = StageParams() - # TODO: pass device, progress image = upscale_outpaint( server, stage, @@ -162,7 +161,7 @@ def run_inpaint_pipeline( 'output image size does not match source, skipping post-blend') image = run_upscale_correction( - server, stage, params, image, upscale=upscale) + job, server, stage, params, image, upscale=upscale) dest = save_image(server, output, image) save_params(server, output, params, size, upscale=upscale, border=border) @@ -186,9 +185,8 @@ def run_upscale_pipeline( # progress = job.get_progress_callback() stage = StageParams() - # TODO: pass device, progress image = run_upscale_correction( - server, stage, params, source_image, upscale=upscale) + job, server, stage, params, source_image, upscale=upscale) dest = save_image(server, output, image) save_params(server, output, params, size, upscale=upscale) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 1b4a1cd2..da210ee6 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -41,6 +41,7 @@ from .chain import ( ChainPipeline, ) from .device_pool import ( + DeviceParams, DevicePoolExecutor, ) from .diffusion.run import ( @@ -168,14 +169,23 @@ def url_from_rule(rule) -> str: return url_for(rule.endpoint, **options) -def pipeline_from_request() -> Tuple[ImageParams, Size]: +def pipeline_from_request() -> Tuple[DeviceParams, ImageParams, Size]: user = request.remote_addr + # platform stuff + device_name = request.args.get('platform', available_platforms[0].device) + device = None + + for platform in available_platforms: + if platform.device == device_name: + device = available_platforms[0] + + if device is None: + raise Exception('unknown device') + # pipeline stuff model = get_not_empty(request.args, 'model', get_config_value('model')) model_path = get_model_path(model) - provider = get_from_map(request.args, 'platform', - platform_providers, get_config_value('platform')) scheduler = get_from_map(request.args, 'scheduler', pipeline_schedulers, get_config_value('scheduler')) @@ -213,12 +223,12 @@ def pipeline_from_request() -> Tuple[ImageParams, Size]: seed = np.random.randint(np.iinfo(np.int32).max) logger.info("request from %s: %s rounds of %s using %s on %s, %sx%s, %s, %s - %s", - user, steps, scheduler.__name__, model_path, provider, width, height, cfg, seed, prompt) + user, steps, scheduler.__name__, model_path, device.provider, width, height, cfg, seed, prompt) - params = ImageParams(model_path, provider, scheduler, prompt, + params = ImageParams(model_path, scheduler, prompt, negative_prompt, cfg, steps, seed) size = Size(width, height) - return (params, size) + return (device, params, size) def border_from_request() -> Border: diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 6917a146..3d765c43 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -7,6 +7,9 @@ from .chain import ( upscale_resrgan, ChainPipeline, ) +from .device_pool import ( + JobContext, +) from .params import ( ImageParams, SizeChart, @@ -21,7 +24,8 @@ logger = getLogger(__name__) def run_upscale_correction( - ctx: ServerContext, + job: JobContext, + server: ServerContext, stage: StageParams, params: ImageParams, image: Image.Image, @@ -51,4 +55,4 @@ def run_upscale_correction( outscale=1) chain.append((correct_gfpgan, stage, None)) - return chain(ctx, params, image, prompt=params.prompt, upscale=upscale) + return chain(job, server, params, image, prompt=params.prompt, upscale=upscale) diff --git a/common/pipelines/outpaint.json b/common/pipelines/outpaint.json index adc57120..cf45aeb1 100644 --- a/common/pipelines/outpaint.json +++ b/common/pipelines/outpaint.json @@ -12,6 +12,7 @@ "type": "upscale-outpaint", "params": { "border": 256, + "model": "stable-diffusion-onnx-v1-inpainting", "prompt": "a magical wizard in a robe fighting a dragon" } },