From de4e7b0dc9b820da61a1dbe1be3dbffda68af5cb Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 6 Feb 2023 17:13:37 -0600 Subject: [PATCH] feat: add face outscale as its own parameter (#111) --- api/onnx_web/chain/blend_img2img.py | 3 +- api/onnx_web/chain/blend_inpaint.py | 4 +-- api/onnx_web/chain/correct_codeformer.py | 5 ++-- api/onnx_web/chain/correct_gfpgan.py | 21 ++++++++----- api/onnx_web/chain/reduce_thumbnail.py | 3 +- .../chain/upscale_stable_diffusion.py | 2 +- api/onnx_web/device_pool.py | 2 +- api/onnx_web/params.py | 4 +++ api/onnx_web/serve.py | 12 ++++---- api/params.json | 6 ++++ api/pyproject.toml | 16 ++++++++++ docs/dev-test.md | 30 +++++++++++++------ docs/user-guide.md | 20 +++++++++++++ gui/src/client.ts | 11 ++++--- gui/src/components/control/UpscaleControl.tsx | 15 +++++++++- gui/src/state.ts | 5 ++-- 16 files changed, 120 insertions(+), 39 deletions(-) diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py index d56c5cb6..a66b0fd6 100644 --- a/api/onnx_web/chain/blend_img2img.py +++ b/api/onnx_web/chain/blend_img2img.py @@ -4,6 +4,7 @@ import numpy as np import torch from diffusers import OnnxStableDiffusionImg2ImgPipeline from PIL import Image +from typing import Optional from ..device_pool import JobContext from ..diffusion.load import load_pipeline @@ -21,7 +22,7 @@ def blend_img2img( source_image: Image.Image, *, strength: float, - prompt: str = None, + prompt: Optional[str] = None, **kwargs, ) -> Image.Image: prompt = prompt or params.prompt diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py index 9ecba374..a980e5a5 100644 --- a/api/onnx_web/chain/blend_inpaint.py +++ b/api/onnx_web/chain/blend_inpaint.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple import numpy as np import torch @@ -25,7 +25,7 @@ def blend_inpaint( source_image: Image.Image, *, expand: Border, - mask_image: Image.Image = None, + mask_image: Optional[Image.Image] = None, fill_color: str = "white", mask_filter: Callable = mask_filter_none, noise_source: Callable = noise_source_histogram, diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py index 60bea425..5dc4bdd8 100644 --- a/api/onnx_web/chain/correct_codeformer.py +++ b/api/onnx_web/chain/correct_codeformer.py @@ -4,7 +4,7 @@ from codeformer import CodeFormer from PIL import Image from ..device_pool import JobContext -from ..params import ImageParams, StageParams +from ..params import ImageParams, StageParams, UpscaleParams from ..utils import ServerContext logger = getLogger(__name__) @@ -20,11 +20,12 @@ def correct_codeformer( source: Image.Image, *, source_image: Image.Image = None, + upscale: UpscaleParams, **kwargs, ) -> Image.Image: device = job.get_device() # TODO: terrible names, fix image = source or source_image - pipe = CodeFormer(upscale=stage.outscale).to(device.torch_device()) + pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device()) return pipe(image) diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py index 4022c06c..5b9018de 100644 --- a/api/onnx_web/chain/correct_gfpgan.py +++ b/api/onnx_web/chain/correct_gfpgan.py @@ -4,23 +4,25 @@ from os import path import numpy as np from gfpgan import GFPGANer from PIL import Image +from typing import Optional from ..device_pool import JobContext from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams from ..utils import ServerContext, run_gc +from .upscale_resrgan import load_resrgan logger = getLogger(__name__) -last_pipeline_instance = None -last_pipeline_params = None +last_pipeline_instance: Optional[GFPGANer] = None +last_pipeline_params: Optional[str] = None -def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParams): +def load_gfpgan(server: ServerContext, stage: StageParams, upscale: UpscaleParams, device: DeviceParams): global last_pipeline_instance global last_pipeline_params - face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model)) + face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model)) if last_pipeline_instance is not None and face_path == last_pipeline_params: logger.info("reusing existing GFPGAN pipeline") @@ -28,12 +30,15 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam logger.debug("loading GFPGAN model from %s", face_path) + upsampler = load_resrgan(server, upscale, device, tile=stage.tile_size) + # TODO: find a way to pass the ONNX model to underlying architectures gfpgan = GFPGANer( - model_path=face_path, - upscale=upscale.outscale, arch="clean", + bg_upsampler=upsampler, channel_multiplier=2, + model_path=face_path, + upscale=upscale.face_outscale, ) last_pipeline_instance = gfpgan @@ -46,7 +51,7 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam def correct_gfpgan( job: JobContext, server: ServerContext, - _stage: StageParams, + stage: StageParams, _params: ImageParams, source_image: Image.Image, *, @@ -59,7 +64,7 @@ def correct_gfpgan( logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model) device = job.get_device() - gfpgan = load_gfpgan(server, upscale, device) + gfpgan = load_gfpgan(server, stage, upscale, device) output = np.array(source_image) _, _, output = gfpgan.enhance( diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py index b5b25130..b6beb9c6 100644 --- a/api/onnx_web/chain/reduce_thumbnail.py +++ b/api/onnx_web/chain/reduce_thumbnail.py @@ -19,6 +19,7 @@ def reduce_thumbnail( size: Size, **kwargs, ) -> Image.Image: - image = source_image.thumbnail((size.width, size.height)) + image = source_image.copy() + image = image.thumbnail((size.width, size.height)) logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height) return image diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py index a7f37d4c..43614d1d 100644 --- a/api/onnx_web/chain/upscale_stable_diffusion.py +++ b/api/onnx_web/chain/upscale_stable_diffusion.py @@ -29,7 +29,7 @@ def load_stable_diffusion( cache_params = (model_path, upscale.format) if last_pipeline_instance is not None and cache_params == last_pipeline_params: - logger.info("reusing existing Stable Diffusion upscale pipeline") + logger.debug("reusing existing Stable Diffusion upscale pipeline") return last_pipeline_instance if upscale.format == "onnx": diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py index 0a5e2026..6ece8322 100644 --- a/api/onnx_web/device_pool.py +++ b/api/onnx_web/device_pool.py @@ -145,7 +145,7 @@ class DevicePoolExecutor: return False - def done(self, key: str) -> Tuple[bool, int]: + def done(self, key: str) -> Tuple[Optional[bool], int]: for job in self.jobs: if job.key == key: done = job.future.done() diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index 525fc292..e8a3b677 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -138,6 +138,7 @@ class UpscaleParams: correction_model: Optional[str] = None, denoise: float = 0.5, faces=True, + face_outscale: int = 1, face_strength: float = 0.5, format: Literal["onnx", "pth"] = "onnx", half=False, @@ -150,6 +151,7 @@ class UpscaleParams: self.correction_model = correction_model self.denoise = denoise self.faces = faces + self.face_outscale = face_outscale self.face_strength = face_strength self.format = format self.half = half @@ -164,6 +166,7 @@ class UpscaleParams: correction_model=self.correction_model, denoise=self.denoise, faces=self.faces, + face_outscale=self.face_outscale, face_strength=self.face_strength, format=self.format, half=self.half, @@ -182,6 +185,7 @@ class UpscaleParams: "correction_model": self.correction_model, "denoise": self.denoise, "faces": self.faces, + "face_outscale": self.face_outscale, "face_strength": self.face_strength, "format": self.format, "half": self.half, diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 419070d2..d4490e7a 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -4,7 +4,7 @@ from glob import glob from io import BytesIO from logging import getLogger from os import makedirs, path -from typing import List, Tuple +from typing import Dict, List, Tuple, Union import numpy as np import torch @@ -80,7 +80,7 @@ from .utils import ( logger = getLogger(__name__) # config caching -config_params = {} +config_params: Dict[str, Dict[str, Union[float, int, str]]] = {} # pipeline params platform_providers = { @@ -136,9 +136,9 @@ chain_stages = { available_platforms: List[DeviceParams] = [] # loaded from model_path -diffusion_models = [] -correction_models = [] -upscaling_models = [] +diffusion_models: List[str] = [] +correction_models: List[str] = [] +upscaling_models: List[str] = [] def get_config_value(key: str, subkey: str = "default", default=None): @@ -269,6 +269,7 @@ def upscale_from_request() -> UpscaleParams: upscaling = get_from_list(request.args, "upscaling", upscaling_models) correction = get_from_list(request.args, "correction", correction_models) faces = get_not_empty(request.args, "faces", "false") == "true" + face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1) face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0) return UpscaleParams( @@ -276,6 +277,7 @@ def upscale_from_request() -> UpscaleParams: correction_model=correction, denoise=denoise, faces=faces, + face_outscale=face_outscale, face_strength=face_strength, format="onnx", outscale=outscale, diff --git a/api/params.json b/api/params.json index 08796608..43331a8d 100644 --- a/api/params.json +++ b/api/params.json @@ -22,6 +22,12 @@ "max": 1, "step": 0.1 }, + "faceOutscale": { + "default": 1, + "min": 1, + "max": 4, + "step": 1 + }, "faceStrength": { "default": 0.5, "min": 0, diff --git a/api/pyproject.toml b/api/pyproject.toml index 4f00474c..efe56db3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -5,3 +5,19 @@ force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_up profile = "black" force_to_top = ".logging" skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion_upscale.py"] + +[tool.mypy] +# ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = [ + "basicsr.archs.rrdbnet_arch", + "boto3", + "codeformer", + "diffusers", + "diffusers.pipeline_utils", + "gfpgan", + "onnxruntime", + "realesrgan" +] +ignore_missing_imports = true \ No newline at end of file diff --git a/docs/dev-test.md b/docs/dev-test.md index 6f139b2a..31b7bd65 100644 --- a/docs/dev-test.md +++ b/docs/dev-test.md @@ -4,26 +4,38 @@ - [Development and Testing](#development-and-testing) - [Contents](#contents) - - [Development](#development) - - [API](#api) - - [GUI](#gui) - - [Updating Github Pages](#updating-github-pages) + - [API Development](#api-development) + - [Style](#style) + - [Models and Pipelines](#models-and-pipelines) + - [GUI Development](#gui-development) + - [Updating Github Pages](#updating-github-pages) - [Testing](#testing) - [Pre-Release Test Plan](#pre-release-test-plan) - [Known Issues](#known-issues) -## Development - -### API +## API Development - TODO: testing - TODO: lint/style -### GUI +### Style + +- all logs must use `logger` from top of file + - every file should have a `logger = getLogger(__name__)` or equivalent before any real code + +### Models and Pipelines + +Loading models and pipelines can be expensive. They should be converted and exported once, then cached per-process +whenever reasonably possible. + +Most pipeline stages will have a corresponding load function somewhere, like `upscale_stable_diffusion` and `load_stable_diffusion`. The load function should compare its parameters and reuse the existing pipeline when +that is possible without causing memory access errors. Most logging from the load function should be `debug` level. + +## GUI Development Run `make ci` to build the bundle. -#### Updating Github Pages +### Updating Github Pages Checkout the `gh-pages` branch and run the `copy-bundle.sh` script, assuming you have the project checked out to a directory named `onnx-web`. diff --git a/docs/user-guide.md b/docs/user-guide.md index d50c09f9..b680f405 100644 --- a/docs/user-guide.md +++ b/docs/user-guide.md @@ -70,6 +70,7 @@ Please see [the server admin guide](server-admin.md) for details on how to confi - [ONNXRuntimeError: The parameter is incorrect](#onnxruntimeerror-the-parameter-is-incorrect) - [The expanded size of the tensor must match the existing size](#the-expanded-size-of-the-tensor-must-match-the-existing-size) - [Shape mismatch attempting to re-use buffer](#shape-mismatch-attempting-to-re-use-buffer) + - [Cannot read properties of undefined (reading 'default')](#cannot-read-properties-of-undefined-reading-default) ## Outline @@ -592,3 +593,22 @@ Example error: [2023-02-04 12:32:54,432] INFO: werkzeug: 10.2.2.16 - - [04/Feb/2023 12:32:54] "GET /api/ready?output=txt2img_1495861691_ccc20fe082567fb4a3471a851db509dc25b4b933dde53db913351be0b617cf85_1 675535574.png HTTP/1.1" 200 - ``` + +#### Cannot read properties of undefined (reading 'default') + +This can happen when you use a newer client with an older version of the server parameters. + +This often means that a parameter is missing from your `params.json` file. If you have not updated your server +recently, try updating and restarting the server. + +If you have customized your `params.json` file, check to make sure it has all of the parameters listed and that the +names are correct (they are case-sensitive). + +Example error: + +```none +Error fetching server parameters +Could not fetch parameters from the ONNX web API server at http://10.2.2.34:5000. + +Cannot read properties of undefined (reading 'default') +``` diff --git a/gui/src/client.ts b/gui/src/client.ts index 9b7930ff..18814d8c 100644 --- a/gui/src/client.ts +++ b/gui/src/client.ts @@ -112,12 +112,13 @@ export interface BrushParams { */ export interface UpscaleParams { enabled: boolean; - denoise: number; - faces: boolean; scale: number; outscale: number; + + faces: boolean; faceStrength: number; + faceOutscale: number; } /** @@ -300,16 +301,14 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) { if (upscale.enabled) { url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT)); url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER)); + url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER)); } if (upscale.faces) { url.searchParams.append('faces', String(upscale.faces)); + url.searchParams.append('faceOutscale', upscale.faceOutscale.toFixed(FIXED_INTEGER)); url.searchParams.append('faceStrength', upscale.faceStrength.toFixed(FIXED_FLOAT)); } - - if (upscale.enabled || upscale.faces) { - url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER)); - } } /** diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx index c1662eb4..98195606 100644 --- a/gui/src/components/control/UpscaleControl.tsx +++ b/gui/src/components/control/UpscaleControl.tsx @@ -56,7 +56,7 @@ export function UpscaleControl() { /> + { + setUpscale({ + faceOutscale, + }); + }} + /> ; } diff --git a/gui/src/state.ts b/gui/src/state.ts index 33882971..9e0c65bf 100644 --- a/gui/src/state.ts +++ b/gui/src/state.ts @@ -374,9 +374,10 @@ export function createStateSlices(server: ServerParams) { denoise: server.denoise.default, enabled: false, faces: false, - scale: server.scale.default, - outscale: server.outscale.default, + faceOutscale: server.faceOutscale.default, faceStrength: server.faceStrength.default, + outscale: server.outscale.default, + scale: server.scale.default, }, upscaleTab: { source: null,