diff --git a/api/onnx_web/chain/blend_img2img.py b/api/onnx_web/chain/blend_img2img.py
index d56c5cb6..a66b0fd6 100644
--- a/api/onnx_web/chain/blend_img2img.py
+++ b/api/onnx_web/chain/blend_img2img.py
@@ -4,6 +4,7 @@ import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
+from typing import Optional
from ..device_pool import JobContext
from ..diffusion.load import load_pipeline
@@ -21,7 +22,7 @@ def blend_img2img(
source_image: Image.Image,
*,
strength: float,
- prompt: str = None,
+ prompt: Optional[str] = None,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt
diff --git a/api/onnx_web/chain/blend_inpaint.py b/api/onnx_web/chain/blend_inpaint.py
index 9ecba374..a980e5a5 100644
--- a/api/onnx_web/chain/blend_inpaint.py
+++ b/api/onnx_web/chain/blend_inpaint.py
@@ -1,5 +1,5 @@
from logging import getLogger
-from typing import Callable, Tuple
+from typing import Callable, Optional, Tuple
import numpy as np
import torch
@@ -25,7 +25,7 @@ def blend_inpaint(
source_image: Image.Image,
*,
expand: Border,
- mask_image: Image.Image = None,
+ mask_image: Optional[Image.Image] = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,
diff --git a/api/onnx_web/chain/correct_codeformer.py b/api/onnx_web/chain/correct_codeformer.py
index 60bea425..5dc4bdd8 100644
--- a/api/onnx_web/chain/correct_codeformer.py
+++ b/api/onnx_web/chain/correct_codeformer.py
@@ -4,7 +4,7 @@ from codeformer import CodeFormer
from PIL import Image
from ..device_pool import JobContext
-from ..params import ImageParams, StageParams
+from ..params import ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext
logger = getLogger(__name__)
@@ -20,11 +20,12 @@ def correct_codeformer(
source: Image.Image,
*,
source_image: Image.Image = None,
+ upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
device = job.get_device()
# TODO: terrible names, fix
image = source or source_image
- pipe = CodeFormer(upscale=stage.outscale).to(device.torch_device())
+ pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
return pipe(image)
diff --git a/api/onnx_web/chain/correct_gfpgan.py b/api/onnx_web/chain/correct_gfpgan.py
index 4022c06c..5b9018de 100644
--- a/api/onnx_web/chain/correct_gfpgan.py
+++ b/api/onnx_web/chain/correct_gfpgan.py
@@ -4,23 +4,25 @@ from os import path
import numpy as np
from gfpgan import GFPGANer
from PIL import Image
+from typing import Optional
from ..device_pool import JobContext
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
+from .upscale_resrgan import load_resrgan
logger = getLogger(__name__)
-last_pipeline_instance = None
-last_pipeline_params = None
+last_pipeline_instance: Optional[GFPGANer] = None
+last_pipeline_params: Optional[str] = None
-def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParams):
+def load_gfpgan(server: ServerContext, stage: StageParams, upscale: UpscaleParams, device: DeviceParams):
global last_pipeline_instance
global last_pipeline_params
- face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))
+ face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
if last_pipeline_instance is not None and face_path == last_pipeline_params:
logger.info("reusing existing GFPGAN pipeline")
@@ -28,12 +30,15 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam
logger.debug("loading GFPGAN model from %s", face_path)
+ upsampler = load_resrgan(server, upscale, device, tile=stage.tile_size)
+
# TODO: find a way to pass the ONNX model to underlying architectures
gfpgan = GFPGANer(
- model_path=face_path,
- upscale=upscale.outscale,
arch="clean",
+ bg_upsampler=upsampler,
channel_multiplier=2,
+ model_path=face_path,
+ upscale=upscale.face_outscale,
)
last_pipeline_instance = gfpgan
@@ -46,7 +51,7 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam
def correct_gfpgan(
job: JobContext,
server: ServerContext,
- _stage: StageParams,
+ stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
*,
@@ -59,7 +64,7 @@ def correct_gfpgan(
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
- gfpgan = load_gfpgan(server, upscale, device)
+ gfpgan = load_gfpgan(server, stage, upscale, device)
output = np.array(source_image)
_, _, output = gfpgan.enhance(
diff --git a/api/onnx_web/chain/reduce_thumbnail.py b/api/onnx_web/chain/reduce_thumbnail.py
index b5b25130..b6beb9c6 100644
--- a/api/onnx_web/chain/reduce_thumbnail.py
+++ b/api/onnx_web/chain/reduce_thumbnail.py
@@ -19,6 +19,7 @@ def reduce_thumbnail(
size: Size,
**kwargs,
) -> Image.Image:
- image = source_image.thumbnail((size.width, size.height))
+ image = source_image.copy()
+ image = image.thumbnail((size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image
diff --git a/api/onnx_web/chain/upscale_stable_diffusion.py b/api/onnx_web/chain/upscale_stable_diffusion.py
index a7f37d4c..43614d1d 100644
--- a/api/onnx_web/chain/upscale_stable_diffusion.py
+++ b/api/onnx_web/chain/upscale_stable_diffusion.py
@@ -29,7 +29,7 @@ def load_stable_diffusion(
cache_params = (model_path, upscale.format)
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
- logger.info("reusing existing Stable Diffusion upscale pipeline")
+ logger.debug("reusing existing Stable Diffusion upscale pipeline")
return last_pipeline_instance
if upscale.format == "onnx":
diff --git a/api/onnx_web/device_pool.py b/api/onnx_web/device_pool.py
index 0a5e2026..6ece8322 100644
--- a/api/onnx_web/device_pool.py
+++ b/api/onnx_web/device_pool.py
@@ -145,7 +145,7 @@ class DevicePoolExecutor:
return False
- def done(self, key: str) -> Tuple[bool, int]:
+ def done(self, key: str) -> Tuple[Optional[bool], int]:
for job in self.jobs:
if job.key == key:
done = job.future.done()
diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py
index 525fc292..e8a3b677 100644
--- a/api/onnx_web/params.py
+++ b/api/onnx_web/params.py
@@ -138,6 +138,7 @@ class UpscaleParams:
correction_model: Optional[str] = None,
denoise: float = 0.5,
faces=True,
+ face_outscale: int = 1,
face_strength: float = 0.5,
format: Literal["onnx", "pth"] = "onnx",
half=False,
@@ -150,6 +151,7 @@ class UpscaleParams:
self.correction_model = correction_model
self.denoise = denoise
self.faces = faces
+ self.face_outscale = face_outscale
self.face_strength = face_strength
self.format = format
self.half = half
@@ -164,6 +166,7 @@ class UpscaleParams:
correction_model=self.correction_model,
denoise=self.denoise,
faces=self.faces,
+ face_outscale=self.face_outscale,
face_strength=self.face_strength,
format=self.format,
half=self.half,
@@ -182,6 +185,7 @@ class UpscaleParams:
"correction_model": self.correction_model,
"denoise": self.denoise,
"faces": self.faces,
+ "face_outscale": self.face_outscale,
"face_strength": self.face_strength,
"format": self.format,
"half": self.half,
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index 419070d2..d4490e7a 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -4,7 +4,7 @@ from glob import glob
from io import BytesIO
from logging import getLogger
from os import makedirs, path
-from typing import List, Tuple
+from typing import Dict, List, Tuple, Union
import numpy as np
import torch
@@ -80,7 +80,7 @@ from .utils import (
logger = getLogger(__name__)
# config caching
-config_params = {}
+config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}
# pipeline params
platform_providers = {
@@ -136,9 +136,9 @@ chain_stages = {
available_platforms: List[DeviceParams] = []
# loaded from model_path
-diffusion_models = []
-correction_models = []
-upscaling_models = []
+diffusion_models: List[str] = []
+correction_models: List[str] = []
+upscaling_models: List[str] = []
def get_config_value(key: str, subkey: str = "default", default=None):
@@ -269,6 +269,7 @@ def upscale_from_request() -> UpscaleParams:
upscaling = get_from_list(request.args, "upscaling", upscaling_models)
correction = get_from_list(request.args, "correction", correction_models)
faces = get_not_empty(request.args, "faces", "false") == "true"
+ face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
return UpscaleParams(
@@ -276,6 +277,7 @@ def upscale_from_request() -> UpscaleParams:
correction_model=correction,
denoise=denoise,
faces=faces,
+ face_outscale=face_outscale,
face_strength=face_strength,
format="onnx",
outscale=outscale,
diff --git a/api/params.json b/api/params.json
index 08796608..43331a8d 100644
--- a/api/params.json
+++ b/api/params.json
@@ -22,6 +22,12 @@
"max": 1,
"step": 0.1
},
+ "faceOutscale": {
+ "default": 1,
+ "min": 1,
+ "max": 4,
+ "step": 1
+ },
"faceStrength": {
"default": 0.5,
"min": 0,
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 4f00474c..efe56db3 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -5,3 +5,19 @@ force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_up
profile = "black"
force_to_top = ".logging"
skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion_upscale.py"]
+
+[tool.mypy]
+# ignore_missing_imports = true
+
+[[tool.mypy.overrides]]
+module = [
+ "basicsr.archs.rrdbnet_arch",
+ "boto3",
+ "codeformer",
+ "diffusers",
+ "diffusers.pipeline_utils",
+ "gfpgan",
+ "onnxruntime",
+ "realesrgan"
+]
+ignore_missing_imports = true
\ No newline at end of file
diff --git a/docs/dev-test.md b/docs/dev-test.md
index 6f139b2a..31b7bd65 100644
--- a/docs/dev-test.md
+++ b/docs/dev-test.md
@@ -4,26 +4,38 @@
- [Development and Testing](#development-and-testing)
- [Contents](#contents)
- - [Development](#development)
- - [API](#api)
- - [GUI](#gui)
- - [Updating Github Pages](#updating-github-pages)
+ - [API Development](#api-development)
+ - [Style](#style)
+ - [Models and Pipelines](#models-and-pipelines)
+ - [GUI Development](#gui-development)
+ - [Updating Github Pages](#updating-github-pages)
- [Testing](#testing)
- [Pre-Release Test Plan](#pre-release-test-plan)
- [Known Issues](#known-issues)
-## Development
-
-### API
+## API Development
- TODO: testing
- TODO: lint/style
-### GUI
+### Style
+
+- all logs must use `logger` from top of file
+ - every file should have a `logger = getLogger(__name__)` or equivalent before any real code
+
+### Models and Pipelines
+
+Loading models and pipelines can be expensive. They should be converted and exported once, then cached per-process
+whenever reasonably possible.
+
+Most pipeline stages will have a corresponding load function somewhere, like `upscale_stable_diffusion` and `load_stable_diffusion`. The load function should compare its parameters and reuse the existing pipeline when
+that is possible without causing memory access errors. Most logging from the load function should be `debug` level.
+
+## GUI Development
Run `make ci` to build the bundle.
-#### Updating Github Pages
+### Updating Github Pages
Checkout the `gh-pages` branch and run the `copy-bundle.sh` script, assuming you have the project
checked out to a directory named `onnx-web`.
diff --git a/docs/user-guide.md b/docs/user-guide.md
index d50c09f9..b680f405 100644
--- a/docs/user-guide.md
+++ b/docs/user-guide.md
@@ -70,6 +70,7 @@ Please see [the server admin guide](server-admin.md) for details on how to confi
- [ONNXRuntimeError: The parameter is incorrect](#onnxruntimeerror-the-parameter-is-incorrect)
- [The expanded size of the tensor must match the existing size](#the-expanded-size-of-the-tensor-must-match-the-existing-size)
- [Shape mismatch attempting to re-use buffer](#shape-mismatch-attempting-to-re-use-buffer)
+ - [Cannot read properties of undefined (reading 'default')](#cannot-read-properties-of-undefined-reading-default)
## Outline
@@ -592,3 +593,22 @@ Example error:
[2023-02-04 12:32:54,432] INFO: werkzeug: 10.2.2.16 - - [04/Feb/2023 12:32:54] "GET /api/ready?output=txt2img_1495861691_ccc20fe082567fb4a3471a851db509dc25b4b933dde53db913351be0b617cf85_1
675535574.png HTTP/1.1" 200 -
```
+
+#### Cannot read properties of undefined (reading 'default')
+
+This can happen when you use a newer client with an older version of the server parameters.
+
+This often means that a parameter is missing from your `params.json` file. If you have not updated your server
+recently, try updating and restarting the server.
+
+If you have customized your `params.json` file, check to make sure it has all of the parameters listed and that the
+names are correct (they are case-sensitive).
+
+Example error:
+
+```none
+Error fetching server parameters
+Could not fetch parameters from the ONNX web API server at http://10.2.2.34:5000.
+
+Cannot read properties of undefined (reading 'default')
+```
diff --git a/gui/src/client.ts b/gui/src/client.ts
index 9b7930ff..18814d8c 100644
--- a/gui/src/client.ts
+++ b/gui/src/client.ts
@@ -112,12 +112,13 @@ export interface BrushParams {
*/
export interface UpscaleParams {
enabled: boolean;
-
denoise: number;
- faces: boolean;
scale: number;
outscale: number;
+
+ faces: boolean;
faceStrength: number;
+ faceOutscale: number;
}
/**
@@ -300,16 +301,14 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
if (upscale.enabled) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
+ url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
}
if (upscale.faces) {
url.searchParams.append('faces', String(upscale.faces));
+ url.searchParams.append('faceOutscale', upscale.faceOutscale.toFixed(FIXED_INTEGER));
url.searchParams.append('faceStrength', upscale.faceStrength.toFixed(FIXED_FLOAT));
}
-
- if (upscale.enabled || upscale.faces) {
- url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
- }
}
/**
diff --git a/gui/src/components/control/UpscaleControl.tsx b/gui/src/components/control/UpscaleControl.tsx
index c1662eb4..98195606 100644
--- a/gui/src/components/control/UpscaleControl.tsx
+++ b/gui/src/components/control/UpscaleControl.tsx
@@ -56,7 +56,7 @@ export function UpscaleControl() {
/>
+ {
+ setUpscale({
+ faceOutscale,
+ });
+ }}
+ />
;
}
diff --git a/gui/src/state.ts b/gui/src/state.ts
index 33882971..9e0c65bf 100644
--- a/gui/src/state.ts
+++ b/gui/src/state.ts
@@ -374,9 +374,10 @@ export function createStateSlices(server: ServerParams) {
denoise: server.denoise.default,
enabled: false,
faces: false,
- scale: server.scale.default,
- outscale: server.outscale.default,
+ faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
+ outscale: server.outscale.default,
+ scale: server.scale.default,
},
upscaleTab: {
source: null,