diff --git a/README.md b/README.md
index f09979b0..5cfa6477 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,7 @@ Based on guides by:
- [For Nvidia everywhere: Install PyTorch GPU and ONNX GPU](#for-nvidia-everywhere-install-pytorch-gpu-and-onnx-gpu)
- [Download and convert models](#download-and-convert-models)
- [Test the models](#test-the-models)
+ - [Upscaling and face correction](#upscaling-and-face-correction)
- [Usage](#usage)
- [Running the containers](#running-the-containers)
- [Configuring and running the server](#configuring-and-running-the-server)
@@ -310,6 +311,13 @@ If the script works, there will be an image of an astronaut in `outputs/test.png
If you get any errors, check [the known errors section](#known-errors-and-solutions).
+### Upscaling and face correction
+
+Models:
+
+- https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
+- https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
+
## Usage
### Running the containers
diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py
index 64726bb5..68b91f21 100644
--- a/api/onnx_web/serve.py
+++ b/api/onnx_web/serve.py
@@ -106,6 +106,12 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen,
}
+# TODO: load from model_path
+upscale_models = [
+ 'RealESRGAN_x4plus',
+ 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', # TODO: convert GFPGAN
+]
+
def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', bundle_path), filename)
@@ -192,9 +198,17 @@ def border_from_request() -> Border:
def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
+ outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
faces = request.args.get('faces', 'false') == 'true'
- platform = 'onnx'
- return UpscaleParams(scale=scale, faces=faces, platform=platform, denoise=denoise)
+ return UpscaleParams(
+ upscale_models[0],
+ scale=scale,
+ outscale=outscale,
+ faces=faces,
+ face_model=upscale_models[1],
+ platform='onnx',
+ denoise=denoise,
+ )
def check_paths():
if not path.exists(model_path):
diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py
index bf87a449..91cdad92 100644
--- a/api/onnx_web/upscale.py
+++ b/api/onnx_web/upscale.py
@@ -1,11 +1,10 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
-from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from onnxruntime import InferenceSession
from os import path
from PIL import Image
from realesrgan import RealESRGANer
-from typing import Any
+from typing import Any, Union
import numpy as np
import torch
@@ -15,16 +14,9 @@ from .utils import (
)
# TODO: these should all be params or config
-fp16 = False
-outscale = 4
pre_pad = 0
tile_pad = 10
-gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
-resrgan_name = 'RealESRGAN_x4plus'
-resrgan_url = [
- 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
-
class ONNXImage():
def __init__(self, source) -> None:
@@ -51,6 +43,9 @@ class ONNXImage():
def numpy(self):
return self.source
+ def size(self):
+ return np.shape(self.source)
+
class ONNXNet():
'''
@@ -87,29 +82,44 @@ class ONNXNet():
class UpscaleParams():
- def __init__(self, scale=4, faces=True, platform='onnx', denoise=0.5) -> None:
- self.denoise = denoise
+ def __init__(
+ self,
+ upscale_model: str,
+ scale: int = 4,
+ outscale: int = 1,
+ denoise: float = 0.5,
+ faces=True,
+ face_model: Union[str, None] = None,
+ platform: str = 'onnx',
+ half=False
+ ) -> None:
+ self.upscale_model = upscale_model
self.scale = scale
+ self.outscale = outscale
+ self.denoise = denoise
self.faces = faces
+ self.face_model = face_model
self.platform = platform
+ self.half = half
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
- model_path = path.join(ctx.model_path, resrgan_name + '.pth')
+ model_path = path.join(ctx.model_path, '%s.%s' %
+ (params.upscale_model, params.platform))
if not path.isfile(model_path):
- for url in resrgan_url:
- model_path = load_file_from_url(
- url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
+ raise Exception('Real ESRGAN model not found at %s' % model_path)
# use ONNX acceleration, if available
if params.platform == 'onnx':
model = ONNXNet(ctx)
- else:
+ elif params.platform == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale)
+ else:
+ raise Exception('unknown platform %s' % params.platform)
dni_weight = None
- if resrgan_name == 'realesr-general-x4v3' and params.denoise != 1:
+ if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path]
@@ -123,7 +133,7 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
tile=tile,
tile_pad=tile_pad,
pre_pad=pre_pad,
- half=fp16)
+ half=params.half)
return upsampler
@@ -134,8 +144,7 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
image = np.array(source_image)
upsampler = make_resrgan(ctx, params)
- # TODO: what is outscale for here?
- output, _ = upsampler.enhance(image, outscale=outscale)
+ output, _ = upsampler.enhance(image, outscale=params.outscale)
if params.faces:
output = upscale_gfpgan(ctx, params, output)
@@ -144,14 +153,18 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
- print('correcting faces with GFPGAN')
+ print('correcting faces with GFPGAN model: %s' % params.face_model)
+
+ if params.face_model is None:
+ print('no face model given, skipping')
+ return image
if upsampler is None:
upsampler = make_resrgan(ctx, params, tile=512)
face_enhancer = GFPGANer(
- model_path=gfpgan_url,
- upscale=outscale,
+ model_path=params.face_model,
+ upscale=params.outscale,
arch='clean',
channel_multiplier=2,
bg_upsampler=upsampler)
diff --git a/api/params.json b/api/params.json
index 445fbab5..f267e4d7 100644
--- a/api/params.json
+++ b/api/params.json
@@ -61,6 +61,12 @@
"max": 1,
"step": 0.01
},
+ "outscale": {
+ "default": 1,
+ "min": 1,
+ "max": 4,
+ "step": 1
+ },
"width": {
"default": 512,
"min": 64,
diff --git a/gui/src/client.ts b/gui/src/client.ts
index 6e3d28a5..c8f4bd3f 100644
--- a/gui/src/client.ts
+++ b/gui/src/client.ts
@@ -71,6 +71,7 @@ export interface UpscaleParams {
denoise: number;
faces: boolean;
scale: number;
+ outscale: number;
}
export interface ApiResponse {
@@ -170,6 +171,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
url.searchParams.append('faces', String(upscale.faces));
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
+ url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
}
export function makeClient(root: string, f = fetch): ApiClient {
diff --git a/gui/src/components/ImageCard.tsx b/gui/src/components/ImageCard.tsx
index e462e460..b29ff395 100644
--- a/gui/src/components/ImageCard.tsx
+++ b/gui/src/components/ImageCard.tsx
@@ -1,5 +1,5 @@
import { doesExist, mustExist } from '@apextoaster/js-utils';
-import { ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material';
+import { Brush, ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material';
import { Box, Button, Card, CardContent, CardMedia, Grid, Paper } from '@mui/material';
import * as React from 'react';
import { useContext } from 'react';
@@ -86,7 +86,7 @@ export function ImageCard(props: ImageCardProps) {
diff --git a/gui/src/components/ImageInput.tsx b/gui/src/components/ImageInput.tsx
index 224ede34..e41ab9e7 100644
--- a/gui/src/components/ImageInput.tsx
+++ b/gui/src/components/ImageInput.tsx
@@ -24,7 +24,13 @@ export function ImageInput(props: ImageInputProps) {
}
if (doesExist(props.image)) {
- return ;
+ return ;
} else {
return Please select an image.
;
}
diff --git a/gui/src/components/MaskCanvas.tsx b/gui/src/components/MaskCanvas.tsx
index f6b65300..3668d767 100644
--- a/gui/src/components/MaskCanvas.tsx
+++ b/gui/src/components/MaskCanvas.tsx
@@ -1,6 +1,6 @@
import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material';
-import { Button, Stack } from '@mui/material';
+import { Button, Stack, Typography } from '@mui/material';
import { throttle } from 'lodash';
import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react';
import { useStore } from 'zustand';
@@ -230,6 +230,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
}
}}
/>
+
+ Black pixels in the mask will stay the same, white pixels will be replaced with pixels from the noise source.
+
+ {
+ setUpscale({
+ outscale,
+ });
+ }}
+ />
({
diff --git a/onnx-web.code-workspace b/onnx-web.code-workspace
index b4296ecf..a200b74b 100644
--- a/onnx-web.code-workspace
+++ b/onnx-web.code-workspace
@@ -34,6 +34,7 @@
"Onnx",
"onnxruntime",
"outpaint",
+ "outscale",
"pndm",
"pretrained",
"protobuf",