From 483b8e3f19801018416c5f0fb67c4e97e06164a8 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Wed, 25 Jan 2023 20:31:39 -0600 Subject: [PATCH] lint(api): consolidate upscale/correction logic --- api/onnx_web/__init__.py | 17 ++++++++++------- api/onnx_web/{pipeline.py => diffusion.py} | 15 +++++---------- api/onnx_web/serve.py | 13 ++++++------- api/onnx_web/upscale.py | 12 ++++++++++++ 4 files changed, 33 insertions(+), 24 deletions(-) rename api/onnx_web/{pipeline.py => diffusion.py} (94%) diff --git a/api/onnx_web/__init__.py b/api/onnx_web/__init__.py index 2a3807ac..b6c372f6 100644 --- a/api/onnx_web/__init__.py +++ b/api/onnx_web/__init__.py @@ -1,3 +1,10 @@ +from .diffusion import ( + get_latents_from_seed, + load_pipeline, + run_img2img_pipeline, + run_inpaint_pipeline, + run_txt2img_pipeline, +) from .image import ( expand_image, mask_filter_gaussian_multiply, @@ -10,15 +17,9 @@ from .image import ( noise_source_normal, noise_source_uniform, ) -from .pipeline import ( - get_latents_from_seed, - load_pipeline, - run_img2img_pipeline, - run_inpaint_pipeline, - run_txt2img_pipeline, -) from .upscale import ( make_resrgan, + run_upscale_pipeline, upscale_gfpgan, upscale_resrgan, UpscaleParams, @@ -26,7 +27,9 @@ from .upscale import ( from .utils import ( get_and_clamp_float, get_and_clamp_int, + get_from_list, get_from_map, + get_not_empty, safer_join, BaseParams, Border, diff --git a/api/onnx_web/pipeline.py b/api/onnx_web/diffusion.py similarity index 94% rename from api/onnx_web/pipeline.py rename to api/onnx_web/diffusion.py index 8fab00e9..f58dd5d7 100644 --- a/api/onnx_web/pipeline.py +++ b/api/onnx_web/diffusion.py @@ -61,7 +61,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu gc.collect() torch.cuda.empty_cache() - print('loading different pipeline') + print('loading new pipeline') pipe = pipeline.from_pretrained( model, provider=provider, @@ -77,7 +77,7 @@ def load_pipeline(pipeline: DiffusionPipeline, model: str, provider: str, schedu last_pipeline_scheduler = scheduler if last_pipeline_scheduler != scheduler: - print('changing pipeline scheduler') + print('loading new scheduler') scheduler = scheduler.from_pretrained( model, subfolder='scheduler') @@ -117,9 +117,7 @@ def run_txt2img_pipeline( num_inference_steps=params.steps, ) image = result.images[0] - - if upscale.faces or upscale.scale > 1: - image = upscale_resrgan(ctx, upscale, image) + image = run_upscale_pipeline(ctx, upscale, image) dest = safer_join(ctx.output_path, output) image.save(dest) @@ -153,9 +151,7 @@ def run_img2img_pipeline( strength=strength, ) image = result.images[0] - - if upscale.faces or upscale.scale > 1: - image = upscale_resrgan(ctx, upscale, image) + image = run_upscale_pipeline(ctx, upscale, image) dest = safer_join(ctx.output_path, output) image.save(dest) @@ -219,8 +215,7 @@ def run_inpaint_pipeline( else: print('output image size does not match source, skipping post-blend') - if upscale.faces or upscale.scale > 1: - image = upscale_resrgan(ctx, upscale, image) + image = run_upscale_pipeline(ctx, upscale, image) dest = safer_join(ctx.output_path, output) image.save(dest) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 4d268b24..1396da7d 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -1,5 +1,4 @@ from diffusers import ( - # schedulers DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, @@ -23,6 +22,12 @@ from onnxruntime import get_available_providers from os import makedirs, path, scandir from typing import Tuple +from .diffusion import ( + run_img2img_pipeline, + run_inpaint_pipeline, + run_txt2img_pipeline, + run_upscale_pipeline, +) from .image import ( # mask filters mask_filter_gaussian_multiply, @@ -36,12 +41,6 @@ from .image import ( noise_source_normal, noise_source_uniform, ) -from .pipeline import ( - run_img2img_pipeline, - run_inpaint_pipeline, - run_txt2img_pipeline, - run_upscale_pipeline, -) from .upscale import ( UpscaleParams, ) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 88ff6122..3cd5b4c8 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -188,3 +188,15 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N image, has_aligned=False, only_center_face=False, paste_back=True, weight=params.face_strength) return output + + +def run_upscale_pipeline(ctx: ServerContext, params: UpscaleParams, image: Image) -> Image: + print('running upscale pipeline') + + if params.scale > 1: + image = upscale_resrgan(ctx, params, image) + + if params.faces: + image = upscale_gfpgan(ctx, params, image) + + return image \ No newline at end of file