1
0
Fork 0

feat: add face outscale as its own parameter (#111)

This commit is contained in:
Sean Sube 2023-02-06 17:13:37 -06:00
parent 564cfc1279
commit de4e7b0dc9
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
16 changed files with 120 additions and 39 deletions

View File

@ -4,6 +4,7 @@ import numpy as np
import torch
from diffusers import OnnxStableDiffusionImg2ImgPipeline
from PIL import Image
from typing import Optional
from ..device_pool import JobContext
from ..diffusion.load import load_pipeline
@ -21,7 +22,7 @@ def blend_img2img(
source_image: Image.Image,
*,
strength: float,
prompt: str = None,
prompt: Optional[str] = None,
**kwargs,
) -> Image.Image:
prompt = prompt or params.prompt

View File

@ -1,5 +1,5 @@
from logging import getLogger
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple
import numpy as np
import torch
@ -25,7 +25,7 @@ def blend_inpaint(
source_image: Image.Image,
*,
expand: Border,
mask_image: Image.Image = None,
mask_image: Optional[Image.Image] = None,
fill_color: str = "white",
mask_filter: Callable = mask_filter_none,
noise_source: Callable = noise_source_histogram,

View File

@ -4,7 +4,7 @@ from codeformer import CodeFormer
from PIL import Image
from ..device_pool import JobContext
from ..params import ImageParams, StageParams
from ..params import ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext
logger = getLogger(__name__)
@ -20,11 +20,12 @@ def correct_codeformer(
source: Image.Image,
*,
source_image: Image.Image = None,
upscale: UpscaleParams,
**kwargs,
) -> Image.Image:
device = job.get_device()
# TODO: terrible names, fix
image = source or source_image
pipe = CodeFormer(upscale=stage.outscale).to(device.torch_device())
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
return pipe(image)

View File

@ -4,23 +4,25 @@ from os import path
import numpy as np
from gfpgan import GFPGANer
from PIL import Image
from typing import Optional
from ..device_pool import JobContext
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
from ..utils import ServerContext, run_gc
from .upscale_resrgan import load_resrgan
logger = getLogger(__name__)
last_pipeline_instance = None
last_pipeline_params = None
last_pipeline_instance: Optional[GFPGANer] = None
last_pipeline_params: Optional[str] = None
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParams):
def load_gfpgan(server: ServerContext, stage: StageParams, upscale: UpscaleParams, device: DeviceParams):
global last_pipeline_instance
global last_pipeline_params
face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))
face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
if last_pipeline_instance is not None and face_path == last_pipeline_params:
logger.info("reusing existing GFPGAN pipeline")
@ -28,12 +30,15 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam
logger.debug("loading GFPGAN model from %s", face_path)
upsampler = load_resrgan(server, upscale, device, tile=stage.tile_size)
# TODO: find a way to pass the ONNX model to underlying architectures
gfpgan = GFPGANer(
model_path=face_path,
upscale=upscale.outscale,
arch="clean",
bg_upsampler=upsampler,
channel_multiplier=2,
model_path=face_path,
upscale=upscale.face_outscale,
)
last_pipeline_instance = gfpgan
@ -46,7 +51,7 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam
def correct_gfpgan(
job: JobContext,
server: ServerContext,
_stage: StageParams,
stage: StageParams,
_params: ImageParams,
source_image: Image.Image,
*,
@ -59,7 +64,7 @@ def correct_gfpgan(
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
device = job.get_device()
gfpgan = load_gfpgan(server, upscale, device)
gfpgan = load_gfpgan(server, stage, upscale, device)
output = np.array(source_image)
_, _, output = gfpgan.enhance(

View File

@ -19,6 +19,7 @@ def reduce_thumbnail(
size: Size,
**kwargs,
) -> Image.Image:
image = source_image.thumbnail((size.width, size.height))
image = source_image.copy()
image = image.thumbnail((size.width, size.height))
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
return image

View File

@ -29,7 +29,7 @@ def load_stable_diffusion(
cache_params = (model_path, upscale.format)
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
logger.info("reusing existing Stable Diffusion upscale pipeline")
logger.debug("reusing existing Stable Diffusion upscale pipeline")
return last_pipeline_instance
if upscale.format == "onnx":

View File

@ -145,7 +145,7 @@ class DevicePoolExecutor:
return False
def done(self, key: str) -> Tuple[bool, int]:
def done(self, key: str) -> Tuple[Optional[bool], int]:
for job in self.jobs:
if job.key == key:
done = job.future.done()

View File

@ -138,6 +138,7 @@ class UpscaleParams:
correction_model: Optional[str] = None,
denoise: float = 0.5,
faces=True,
face_outscale: int = 1,
face_strength: float = 0.5,
format: Literal["onnx", "pth"] = "onnx",
half=False,
@ -150,6 +151,7 @@ class UpscaleParams:
self.correction_model = correction_model
self.denoise = denoise
self.faces = faces
self.face_outscale = face_outscale
self.face_strength = face_strength
self.format = format
self.half = half
@ -164,6 +166,7 @@ class UpscaleParams:
correction_model=self.correction_model,
denoise=self.denoise,
faces=self.faces,
face_outscale=self.face_outscale,
face_strength=self.face_strength,
format=self.format,
half=self.half,
@ -182,6 +185,7 @@ class UpscaleParams:
"correction_model": self.correction_model,
"denoise": self.denoise,
"faces": self.faces,
"face_outscale": self.face_outscale,
"face_strength": self.face_strength,
"format": self.format,
"half": self.half,

View File

@ -4,7 +4,7 @@ from glob import glob
from io import BytesIO
from logging import getLogger
from os import makedirs, path
from typing import List, Tuple
from typing import Dict, List, Tuple, Union
import numpy as np
import torch
@ -80,7 +80,7 @@ from .utils import (
logger = getLogger(__name__)
# config caching
config_params = {}
config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}
# pipeline params
platform_providers = {
@ -136,9 +136,9 @@ chain_stages = {
available_platforms: List[DeviceParams] = []
# loaded from model_path
diffusion_models = []
correction_models = []
upscaling_models = []
diffusion_models: List[str] = []
correction_models: List[str] = []
upscaling_models: List[str] = []
def get_config_value(key: str, subkey: str = "default", default=None):
@ -269,6 +269,7 @@ def upscale_from_request() -> UpscaleParams:
upscaling = get_from_list(request.args, "upscaling", upscaling_models)
correction = get_from_list(request.args, "correction", correction_models)
faces = get_not_empty(request.args, "faces", "false") == "true"
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
return UpscaleParams(
@ -276,6 +277,7 @@ def upscale_from_request() -> UpscaleParams:
correction_model=correction,
denoise=denoise,
faces=faces,
face_outscale=face_outscale,
face_strength=face_strength,
format="onnx",
outscale=outscale,

View File

@ -22,6 +22,12 @@
"max": 1,
"step": 0.1
},
"faceOutscale": {
"default": 1,
"min": 1,
"max": 4,
"step": 1
},
"faceStrength": {
"default": 0.5,
"min": 0,

View File

@ -5,3 +5,19 @@ force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_up
profile = "black"
force_to_top = ".logging"
skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion_upscale.py"]
[tool.mypy]
# ignore_missing_imports = true
[[tool.mypy.overrides]]
module = [
"basicsr.archs.rrdbnet_arch",
"boto3",
"codeformer",
"diffusers",
"diffusers.pipeline_utils",
"gfpgan",
"onnxruntime",
"realesrgan"
]
ignore_missing_imports = true

View File

@ -4,26 +4,38 @@
- [Development and Testing](#development-and-testing)
- [Contents](#contents)
- [Development](#development)
- [API](#api)
- [GUI](#gui)
- [Updating Github Pages](#updating-github-pages)
- [API Development](#api-development)
- [Style](#style)
- [Models and Pipelines](#models-and-pipelines)
- [GUI Development](#gui-development)
- [Updating Github Pages](#updating-github-pages)
- [Testing](#testing)
- [Pre-Release Test Plan](#pre-release-test-plan)
- [Known Issues](#known-issues)
## Development
### API
## API Development
- TODO: testing
- TODO: lint/style
### GUI
### Style
- all logs must use `logger` from top of file
- every file should have a `logger = getLogger(__name__)` or equivalent before any real code
### Models and Pipelines
Loading models and pipelines can be expensive. They should be converted and exported once, then cached per-process
whenever reasonably possible.
Most pipeline stages will have a corresponding load function somewhere, like `upscale_stable_diffusion` and `load_stable_diffusion`. The load function should compare its parameters and reuse the existing pipeline when
that is possible without causing memory access errors. Most logging from the load function should be `debug` level.
## GUI Development
Run `make ci` to build the bundle.
#### Updating Github Pages
### Updating Github Pages
Checkout the `gh-pages` branch and run the `copy-bundle.sh` script, assuming you have the project
checked out to a directory named `onnx-web`.

View File

@ -70,6 +70,7 @@ Please see [the server admin guide](server-admin.md) for details on how to confi
- [ONNXRuntimeError: The parameter is incorrect](#onnxruntimeerror-the-parameter-is-incorrect)
- [The expanded size of the tensor must match the existing size](#the-expanded-size-of-the-tensor-must-match-the-existing-size)
- [Shape mismatch attempting to re-use buffer](#shape-mismatch-attempting-to-re-use-buffer)
- [Cannot read properties of undefined (reading 'default')](#cannot-read-properties-of-undefined-reading-default)
## Outline
@ -592,3 +593,22 @@ Example error:
[2023-02-04 12:32:54,432] INFO: werkzeug: 10.2.2.16 - - [04/Feb/2023 12:32:54] "GET /api/ready?output=txt2img_1495861691_ccc20fe082567fb4a3471a851db509dc25b4b933dde53db913351be0b617cf85_1
675535574.png HTTP/1.1" 200 -
```
#### Cannot read properties of undefined (reading 'default')
This can happen when you use a newer client with an older version of the server parameters.
This often means that a parameter is missing from your `params.json` file. If you have not updated your server
recently, try updating and restarting the server.
If you have customized your `params.json` file, check to make sure it has all of the parameters listed and that the
names are correct (they are case-sensitive).
Example error:
```none
Error fetching server parameters
Could not fetch parameters from the ONNX web API server at http://10.2.2.34:5000.
Cannot read properties of undefined (reading 'default')
```

View File

@ -112,12 +112,13 @@ export interface BrushParams {
*/
export interface UpscaleParams {
enabled: boolean;
denoise: number;
faces: boolean;
scale: number;
outscale: number;
faces: boolean;
faceStrength: number;
faceOutscale: number;
}
/**
@ -300,16 +301,14 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
if (upscale.enabled) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
}
if (upscale.faces) {
url.searchParams.append('faces', String(upscale.faces));
url.searchParams.append('faceOutscale', upscale.faceOutscale.toFixed(FIXED_INTEGER));
url.searchParams.append('faceStrength', upscale.faceStrength.toFixed(FIXED_FLOAT));
}
if (upscale.enabled || upscale.faces) {
url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
}
}
/**

View File

@ -56,7 +56,7 @@ export function UpscaleControl() {
/>
<NumericField
label='Outscale'
disabled={upscale.enabled === false && upscale.faces === false}
disabled={upscale.enabled === false}
min={params.outscale.min}
max={params.outscale.max}
step={params.outscale.step}
@ -93,5 +93,18 @@ export function UpscaleControl() {
});
}}
/>
<NumericField
label='Outscale'
disabled={upscale.faces === false}
min={params.faceOutscale.min}
max={params.faceOutscale.max}
step={params.faceOutscale.step}
value={upscale.faceOutscale}
onChange={(faceOutscale) => {
setUpscale({
faceOutscale,
});
}}
/>
</Stack>;
}

View File

@ -374,9 +374,10 @@ export function createStateSlices(server: ServerParams) {
denoise: server.denoise.default,
enabled: false,
faces: false,
scale: server.scale.default,
outscale: server.outscale.default,
faceOutscale: server.faceOutscale.default,
faceStrength: server.faceStrength.default,
outscale: server.outscale.default,
scale: server.scale.default,
},
upscaleTab: {
source: null,