feat: add face outscale as its own parameter (#111)
This commit is contained in:
parent
564cfc1279
commit
de4e7b0dc9
|
@ -4,6 +4,7 @@ import numpy as np
|
|||
import torch
|
||||
from diffusers import OnnxStableDiffusionImg2ImgPipeline
|
||||
from PIL import Image
|
||||
from typing import Optional
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..diffusion.load import load_pipeline
|
||||
|
@ -21,7 +22,7 @@ def blend_img2img(
|
|||
source_image: Image.Image,
|
||||
*,
|
||||
strength: float,
|
||||
prompt: str = None,
|
||||
prompt: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
prompt = prompt or params.prompt
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -25,7 +25,7 @@ def blend_inpaint(
|
|||
source_image: Image.Image,
|
||||
*,
|
||||
expand: Border,
|
||||
mask_image: Image.Image = None,
|
||||
mask_image: Optional[Image.Image] = None,
|
||||
fill_color: str = "white",
|
||||
mask_filter: Callable = mask_filter_none,
|
||||
noise_source: Callable = noise_source_histogram,
|
||||
|
|
|
@ -4,7 +4,7 @@ from codeformer import CodeFormer
|
|||
from PIL import Image
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import ImageParams, StageParams
|
||||
from ..params import ImageParams, StageParams, UpscaleParams
|
||||
from ..utils import ServerContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -20,11 +20,12 @@ def correct_codeformer(
|
|||
source: Image.Image,
|
||||
*,
|
||||
source_image: Image.Image = None,
|
||||
upscale: UpscaleParams,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
device = job.get_device()
|
||||
# TODO: terrible names, fix
|
||||
image = source or source_image
|
||||
|
||||
pipe = CodeFormer(upscale=stage.outscale).to(device.torch_device())
|
||||
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
|
||||
return pipe(image)
|
||||
|
|
|
@ -4,23 +4,25 @@ from os import path
|
|||
import numpy as np
|
||||
from gfpgan import GFPGANer
|
||||
from PIL import Image
|
||||
from typing import Optional
|
||||
|
||||
from ..device_pool import JobContext
|
||||
from ..params import DeviceParams, ImageParams, StageParams, UpscaleParams
|
||||
from ..utils import ServerContext, run_gc
|
||||
from .upscale_resrgan import load_resrgan
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
last_pipeline_instance = None
|
||||
last_pipeline_params = None
|
||||
last_pipeline_instance: Optional[GFPGANer] = None
|
||||
last_pipeline_params: Optional[str] = None
|
||||
|
||||
|
||||
def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParams):
|
||||
def load_gfpgan(server: ServerContext, stage: StageParams, upscale: UpscaleParams, device: DeviceParams):
|
||||
global last_pipeline_instance
|
||||
global last_pipeline_params
|
||||
|
||||
face_path = path.join(ctx.model_path, "%s.pth" % (upscale.correction_model))
|
||||
face_path = path.join(server.model_path, "%s.pth" % (upscale.correction_model))
|
||||
|
||||
if last_pipeline_instance is not None and face_path == last_pipeline_params:
|
||||
logger.info("reusing existing GFPGAN pipeline")
|
||||
|
@ -28,12 +30,15 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam
|
|||
|
||||
logger.debug("loading GFPGAN model from %s", face_path)
|
||||
|
||||
upsampler = load_resrgan(server, upscale, device, tile=stage.tile_size)
|
||||
|
||||
# TODO: find a way to pass the ONNX model to underlying architectures
|
||||
gfpgan = GFPGANer(
|
||||
model_path=face_path,
|
||||
upscale=upscale.outscale,
|
||||
arch="clean",
|
||||
bg_upsampler=upsampler,
|
||||
channel_multiplier=2,
|
||||
model_path=face_path,
|
||||
upscale=upscale.face_outscale,
|
||||
)
|
||||
|
||||
last_pipeline_instance = gfpgan
|
||||
|
@ -46,7 +51,7 @@ def load_gfpgan(ctx: ServerContext, upscale: UpscaleParams, _device: DeviceParam
|
|||
def correct_gfpgan(
|
||||
job: JobContext,
|
||||
server: ServerContext,
|
||||
_stage: StageParams,
|
||||
stage: StageParams,
|
||||
_params: ImageParams,
|
||||
source_image: Image.Image,
|
||||
*,
|
||||
|
@ -59,7 +64,7 @@ def correct_gfpgan(
|
|||
|
||||
logger.info("correcting faces with GFPGAN model: %s", upscale.correction_model)
|
||||
device = job.get_device()
|
||||
gfpgan = load_gfpgan(server, upscale, device)
|
||||
gfpgan = load_gfpgan(server, stage, upscale, device)
|
||||
|
||||
output = np.array(source_image)
|
||||
_, _, output = gfpgan.enhance(
|
||||
|
|
|
@ -19,6 +19,7 @@ def reduce_thumbnail(
|
|||
size: Size,
|
||||
**kwargs,
|
||||
) -> Image.Image:
|
||||
image = source_image.thumbnail((size.width, size.height))
|
||||
image = source_image.copy()
|
||||
image = image.thumbnail((size.width, size.height))
|
||||
logger.info("created thumbnail with dimensions: %sx%s", image.width, image.height)
|
||||
return image
|
||||
|
|
|
@ -29,7 +29,7 @@ def load_stable_diffusion(
|
|||
cache_params = (model_path, upscale.format)
|
||||
|
||||
if last_pipeline_instance is not None and cache_params == last_pipeline_params:
|
||||
logger.info("reusing existing Stable Diffusion upscale pipeline")
|
||||
logger.debug("reusing existing Stable Diffusion upscale pipeline")
|
||||
return last_pipeline_instance
|
||||
|
||||
if upscale.format == "onnx":
|
||||
|
|
|
@ -145,7 +145,7 @@ class DevicePoolExecutor:
|
|||
|
||||
return False
|
||||
|
||||
def done(self, key: str) -> Tuple[bool, int]:
|
||||
def done(self, key: str) -> Tuple[Optional[bool], int]:
|
||||
for job in self.jobs:
|
||||
if job.key == key:
|
||||
done = job.future.done()
|
||||
|
|
|
@ -138,6 +138,7 @@ class UpscaleParams:
|
|||
correction_model: Optional[str] = None,
|
||||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_outscale: int = 1,
|
||||
face_strength: float = 0.5,
|
||||
format: Literal["onnx", "pth"] = "onnx",
|
||||
half=False,
|
||||
|
@ -150,6 +151,7 @@ class UpscaleParams:
|
|||
self.correction_model = correction_model
|
||||
self.denoise = denoise
|
||||
self.faces = faces
|
||||
self.face_outscale = face_outscale
|
||||
self.face_strength = face_strength
|
||||
self.format = format
|
||||
self.half = half
|
||||
|
@ -164,6 +166,7 @@ class UpscaleParams:
|
|||
correction_model=self.correction_model,
|
||||
denoise=self.denoise,
|
||||
faces=self.faces,
|
||||
face_outscale=self.face_outscale,
|
||||
face_strength=self.face_strength,
|
||||
format=self.format,
|
||||
half=self.half,
|
||||
|
@ -182,6 +185,7 @@ class UpscaleParams:
|
|||
"correction_model": self.correction_model,
|
||||
"denoise": self.denoise,
|
||||
"faces": self.faces,
|
||||
"face_outscale": self.face_outscale,
|
||||
"face_strength": self.face_strength,
|
||||
"format": self.format,
|
||||
"half": self.half,
|
||||
|
|
|
@ -4,7 +4,7 @@ from glob import glob
|
|||
from io import BytesIO
|
||||
from logging import getLogger
|
||||
from os import makedirs, path
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -80,7 +80,7 @@ from .utils import (
|
|||
logger = getLogger(__name__)
|
||||
|
||||
# config caching
|
||||
config_params = {}
|
||||
config_params: Dict[str, Dict[str, Union[float, int, str]]] = {}
|
||||
|
||||
# pipeline params
|
||||
platform_providers = {
|
||||
|
@ -136,9 +136,9 @@ chain_stages = {
|
|||
available_platforms: List[DeviceParams] = []
|
||||
|
||||
# loaded from model_path
|
||||
diffusion_models = []
|
||||
correction_models = []
|
||||
upscaling_models = []
|
||||
diffusion_models: List[str] = []
|
||||
correction_models: List[str] = []
|
||||
upscaling_models: List[str] = []
|
||||
|
||||
|
||||
def get_config_value(key: str, subkey: str = "default", default=None):
|
||||
|
@ -269,6 +269,7 @@ def upscale_from_request() -> UpscaleParams:
|
|||
upscaling = get_from_list(request.args, "upscaling", upscaling_models)
|
||||
correction = get_from_list(request.args, "correction", correction_models)
|
||||
faces = get_not_empty(request.args, "faces", "false") == "true"
|
||||
face_outscale = get_and_clamp_int(request.args, "faceOutscale", 1, 4, 1)
|
||||
face_strength = get_and_clamp_float(request.args, "faceStrength", 0.5, 1.0, 0.0)
|
||||
|
||||
return UpscaleParams(
|
||||
|
@ -276,6 +277,7 @@ def upscale_from_request() -> UpscaleParams:
|
|||
correction_model=correction,
|
||||
denoise=denoise,
|
||||
faces=faces,
|
||||
face_outscale=face_outscale,
|
||||
face_strength=face_strength,
|
||||
format="onnx",
|
||||
outscale=outscale,
|
||||
|
|
|
@ -22,6 +22,12 @@
|
|||
"max": 1,
|
||||
"step": 0.1
|
||||
},
|
||||
"faceOutscale": {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 4,
|
||||
"step": 1
|
||||
},
|
||||
"faceStrength": {
|
||||
"default": 0.5,
|
||||
"min": 0,
|
||||
|
|
|
@ -5,3 +5,19 @@ force-exclude = '''/(lpw_stable_diffusion_onnx|pipeline_onnx_stable_diffusion_up
|
|||
profile = "black"
|
||||
force_to_top = ".logging"
|
||||
skip_glob = ["*/lpw_stable_diffusion_onnx.py", "*/pipeline_onnx_stable_diffusion_upscale.py"]
|
||||
|
||||
[tool.mypy]
|
||||
# ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"basicsr.archs.rrdbnet_arch",
|
||||
"boto3",
|
||||
"codeformer",
|
||||
"diffusers",
|
||||
"diffusers.pipeline_utils",
|
||||
"gfpgan",
|
||||
"onnxruntime",
|
||||
"realesrgan"
|
||||
]
|
||||
ignore_missing_imports = true
|
|
@ -4,26 +4,38 @@
|
|||
|
||||
- [Development and Testing](#development-and-testing)
|
||||
- [Contents](#contents)
|
||||
- [Development](#development)
|
||||
- [API](#api)
|
||||
- [GUI](#gui)
|
||||
- [API Development](#api-development)
|
||||
- [Style](#style)
|
||||
- [Models and Pipelines](#models-and-pipelines)
|
||||
- [GUI Development](#gui-development)
|
||||
- [Updating Github Pages](#updating-github-pages)
|
||||
- [Testing](#testing)
|
||||
- [Pre-Release Test Plan](#pre-release-test-plan)
|
||||
- [Known Issues](#known-issues)
|
||||
|
||||
## Development
|
||||
|
||||
### API
|
||||
## API Development
|
||||
|
||||
- TODO: testing
|
||||
- TODO: lint/style
|
||||
|
||||
### GUI
|
||||
### Style
|
||||
|
||||
- all logs must use `logger` from top of file
|
||||
- every file should have a `logger = getLogger(__name__)` or equivalent before any real code
|
||||
|
||||
### Models and Pipelines
|
||||
|
||||
Loading models and pipelines can be expensive. They should be converted and exported once, then cached per-process
|
||||
whenever reasonably possible.
|
||||
|
||||
Most pipeline stages will have a corresponding load function somewhere, like `upscale_stable_diffusion` and `load_stable_diffusion`. The load function should compare its parameters and reuse the existing pipeline when
|
||||
that is possible without causing memory access errors. Most logging from the load function should be `debug` level.
|
||||
|
||||
## GUI Development
|
||||
|
||||
Run `make ci` to build the bundle.
|
||||
|
||||
#### Updating Github Pages
|
||||
### Updating Github Pages
|
||||
|
||||
Checkout the `gh-pages` branch and run the `copy-bundle.sh` script, assuming you have the project
|
||||
checked out to a directory named `onnx-web`.
|
||||
|
|
|
@ -70,6 +70,7 @@ Please see [the server admin guide](server-admin.md) for details on how to confi
|
|||
- [ONNXRuntimeError: The parameter is incorrect](#onnxruntimeerror-the-parameter-is-incorrect)
|
||||
- [The expanded size of the tensor must match the existing size](#the-expanded-size-of-the-tensor-must-match-the-existing-size)
|
||||
- [Shape mismatch attempting to re-use buffer](#shape-mismatch-attempting-to-re-use-buffer)
|
||||
- [Cannot read properties of undefined (reading 'default')](#cannot-read-properties-of-undefined-reading-default)
|
||||
|
||||
## Outline
|
||||
|
||||
|
@ -592,3 +593,22 @@ Example error:
|
|||
[2023-02-04 12:32:54,432] INFO: werkzeug: 10.2.2.16 - - [04/Feb/2023 12:32:54] "GET /api/ready?output=txt2img_1495861691_ccc20fe082567fb4a3471a851db509dc25b4b933dde53db913351be0b617cf85_1
|
||||
675535574.png HTTP/1.1" 200 -
|
||||
```
|
||||
|
||||
#### Cannot read properties of undefined (reading 'default')
|
||||
|
||||
This can happen when you use a newer client with an older version of the server parameters.
|
||||
|
||||
This often means that a parameter is missing from your `params.json` file. If you have not updated your server
|
||||
recently, try updating and restarting the server.
|
||||
|
||||
If you have customized your `params.json` file, check to make sure it has all of the parameters listed and that the
|
||||
names are correct (they are case-sensitive).
|
||||
|
||||
Example error:
|
||||
|
||||
```none
|
||||
Error fetching server parameters
|
||||
Could not fetch parameters from the ONNX web API server at http://10.2.2.34:5000.
|
||||
|
||||
Cannot read properties of undefined (reading 'default')
|
||||
```
|
||||
|
|
|
@ -112,12 +112,13 @@ export interface BrushParams {
|
|||
*/
|
||||
export interface UpscaleParams {
|
||||
enabled: boolean;
|
||||
|
||||
denoise: number;
|
||||
faces: boolean;
|
||||
scale: number;
|
||||
outscale: number;
|
||||
|
||||
faces: boolean;
|
||||
faceStrength: number;
|
||||
faceOutscale: number;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -300,16 +301,14 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
|
|||
if (upscale.enabled) {
|
||||
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
|
||||
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
|
||||
url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
|
||||
}
|
||||
|
||||
if (upscale.faces) {
|
||||
url.searchParams.append('faces', String(upscale.faces));
|
||||
url.searchParams.append('faceOutscale', upscale.faceOutscale.toFixed(FIXED_INTEGER));
|
||||
url.searchParams.append('faceStrength', upscale.faceStrength.toFixed(FIXED_FLOAT));
|
||||
}
|
||||
|
||||
if (upscale.enabled || upscale.faces) {
|
||||
url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -56,7 +56,7 @@ export function UpscaleControl() {
|
|||
/>
|
||||
<NumericField
|
||||
label='Outscale'
|
||||
disabled={upscale.enabled === false && upscale.faces === false}
|
||||
disabled={upscale.enabled === false}
|
||||
min={params.outscale.min}
|
||||
max={params.outscale.max}
|
||||
step={params.outscale.step}
|
||||
|
@ -93,5 +93,18 @@ export function UpscaleControl() {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label='Outscale'
|
||||
disabled={upscale.faces === false}
|
||||
min={params.faceOutscale.min}
|
||||
max={params.faceOutscale.max}
|
||||
step={params.faceOutscale.step}
|
||||
value={upscale.faceOutscale}
|
||||
onChange={(faceOutscale) => {
|
||||
setUpscale({
|
||||
faceOutscale,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</Stack>;
|
||||
}
|
||||
|
|
|
@ -374,9 +374,10 @@ export function createStateSlices(server: ServerParams) {
|
|||
denoise: server.denoise.default,
|
||||
enabled: false,
|
||||
faces: false,
|
||||
scale: server.scale.default,
|
||||
outscale: server.outscale.default,
|
||||
faceOutscale: server.faceOutscale.default,
|
||||
faceStrength: server.faceStrength.default,
|
||||
outscale: server.outscale.default,
|
||||
scale: server.scale.default,
|
||||
},
|
||||
upscaleTab: {
|
||||
source: null,
|
||||
|
|
Loading…
Reference in New Issue