1
0
Fork 0

feat: add outscaling option

This commit is contained in:
Sean Sube 2023-01-16 14:52:56 -06:00
parent 091c4e6109
commit 8d3ebede5a
11 changed files with 96 additions and 29 deletions

View File

@ -59,6 +59,7 @@ Based on guides by:
- [For Nvidia everywhere: Install PyTorch GPU and ONNX GPU](#for-nvidia-everywhere-install-pytorch-gpu-and-onnx-gpu) - [For Nvidia everywhere: Install PyTorch GPU and ONNX GPU](#for-nvidia-everywhere-install-pytorch-gpu-and-onnx-gpu)
- [Download and convert models](#download-and-convert-models) - [Download and convert models](#download-and-convert-models)
- [Test the models](#test-the-models) - [Test the models](#test-the-models)
- [Upscaling and face correction](#upscaling-and-face-correction)
- [Usage](#usage) - [Usage](#usage)
- [Running the containers](#running-the-containers) - [Running the containers](#running-the-containers)
- [Configuring and running the server](#configuring-and-running-the-server) - [Configuring and running the server](#configuring-and-running-the-server)
@ -310,6 +311,13 @@ If the script works, there will be an image of an astronaut in `outputs/test.png
If you get any errors, check [the known errors section](#known-errors-and-solutions). If you get any errors, check [the known errors section](#known-errors-and-solutions).
### Upscaling and face correction
Models:
- https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
- https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
## Usage ## Usage
### Running the containers ### Running the containers

View File

@ -106,6 +106,12 @@ mask_filters = {
'gaussian-screen': mask_filter_gaussian_screen, 'gaussian-screen': mask_filter_gaussian_screen,
} }
# TODO: load from model_path
upscale_models = [
'RealESRGAN_x4plus',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', # TODO: convert GFPGAN
]
def serve_bundle_file(filename='index.html'): def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', bundle_path), filename) return send_from_directory(path.join('..', bundle_path), filename)
@ -192,9 +198,17 @@ def border_from_request() -> Border:
def upscale_from_request() -> UpscaleParams: def upscale_from_request() -> UpscaleParams:
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0) denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1) scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
faces = request.args.get('faces', 'false') == 'true' faces = request.args.get('faces', 'false') == 'true'
platform = 'onnx' return UpscaleParams(
return UpscaleParams(scale=scale, faces=faces, platform=platform, denoise=denoise) upscale_models[0],
scale=scale,
outscale=outscale,
faces=faces,
face_model=upscale_models[1],
platform='onnx',
denoise=denoise,
)
def check_paths(): def check_paths():
if not path.exists(model_path): if not path.exists(model_path):

View File

@ -1,11 +1,10 @@
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer from gfpgan import GFPGANer
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from os import path from os import path
from PIL import Image from PIL import Image
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from typing import Any from typing import Any, Union
import numpy as np import numpy as np
import torch import torch
@ -15,16 +14,9 @@ from .utils import (
) )
# TODO: these should all be params or config # TODO: these should all be params or config
fp16 = False
outscale = 4
pre_pad = 0 pre_pad = 0
tile_pad = 10 tile_pad = 10
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
resrgan_name = 'RealESRGAN_x4plus'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
class ONNXImage(): class ONNXImage():
def __init__(self, source) -> None: def __init__(self, source) -> None:
@ -51,6 +43,9 @@ class ONNXImage():
def numpy(self): def numpy(self):
return self.source return self.source
def size(self):
return np.shape(self.source)
class ONNXNet(): class ONNXNet():
''' '''
@ -87,29 +82,44 @@ class ONNXNet():
class UpscaleParams(): class UpscaleParams():
def __init__(self, scale=4, faces=True, platform='onnx', denoise=0.5) -> None: def __init__(
self.denoise = denoise self,
upscale_model: str,
scale: int = 4,
outscale: int = 1,
denoise: float = 0.5,
faces=True,
face_model: Union[str, None] = None,
platform: str = 'onnx',
half=False
) -> None:
self.upscale_model = upscale_model
self.scale = scale self.scale = scale
self.outscale = outscale
self.denoise = denoise
self.faces = faces self.faces = faces
self.face_model = face_model
self.platform = platform self.platform = platform
self.half = half
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0): def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
model_path = path.join(ctx.model_path, resrgan_name + '.pth') model_path = path.join(ctx.model_path, '%s.%s' %
(params.upscale_model, params.platform))
if not path.isfile(model_path): if not path.isfile(model_path):
for url in resrgan_url: raise Exception('Real ESRGAN model not found at %s' % model_path)
model_path = load_file_from_url(
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
# use ONNX acceleration, if available # use ONNX acceleration, if available
if params.platform == 'onnx': if params.platform == 'onnx':
model = ONNXNet(ctx) model = ONNXNet(ctx)
else: elif params.platform == 'pth':
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=params.scale) num_block=23, num_grow_ch=32, scale=params.scale)
else:
raise Exception('unknown platform %s' % params.platform)
dni_weight = None dni_weight = None
if resrgan_name == 'realesr-general-x4v3' and params.denoise != 1: if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
wdn_model_path = model_path.replace( wdn_model_path = model_path.replace(
'realesr-general-x4v3', 'realesr-general-wdn-x4v3') 'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
model_path = [model_path, wdn_model_path] model_path = [model_path, wdn_model_path]
@ -123,7 +133,7 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
tile=tile, tile=tile,
tile_pad=tile_pad, tile_pad=tile_pad,
pre_pad=pre_pad, pre_pad=pre_pad,
half=fp16) half=params.half)
return upsampler return upsampler
@ -134,8 +144,7 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
image = np.array(source_image) image = np.array(source_image)
upsampler = make_resrgan(ctx, params) upsampler = make_resrgan(ctx, params)
# TODO: what is outscale for here? output, _ = upsampler.enhance(image, outscale=params.outscale)
output, _ = upsampler.enhance(image, outscale=outscale)
if params.faces: if params.faces:
output = upscale_gfpgan(ctx, params, output) output = upscale_gfpgan(ctx, params, output)
@ -144,14 +153,18 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image: def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
print('correcting faces with GFPGAN') print('correcting faces with GFPGAN model: %s' % params.face_model)
if params.face_model is None:
print('no face model given, skipping')
return image
if upsampler is None: if upsampler is None:
upsampler = make_resrgan(ctx, params, tile=512) upsampler = make_resrgan(ctx, params, tile=512)
face_enhancer = GFPGANer( face_enhancer = GFPGANer(
model_path=gfpgan_url, model_path=params.face_model,
upscale=outscale, upscale=params.outscale,
arch='clean', arch='clean',
channel_multiplier=2, channel_multiplier=2,
bg_upsampler=upsampler) bg_upsampler=upsampler)

View File

@ -61,6 +61,12 @@
"max": 1, "max": 1,
"step": 0.01 "step": 0.01
}, },
"outscale": {
"default": 1,
"min": 1,
"max": 4,
"step": 1
},
"width": { "width": {
"default": 512, "default": 512,
"min": 64, "min": 64,

View File

@ -71,6 +71,7 @@ export interface UpscaleParams {
denoise: number; denoise: number;
faces: boolean; faces: boolean;
scale: number; scale: number;
outscale: number;
} }
export interface ApiResponse { export interface ApiResponse {
@ -170,6 +171,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT)); url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
url.searchParams.append('faces', String(upscale.faces)); url.searchParams.append('faces', String(upscale.faces));
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER)); url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
} }
export function makeClient(root: string, f = fetch): ApiClient { export function makeClient(root: string, f = fetch): ApiClient {

View File

@ -1,5 +1,5 @@
import { doesExist, mustExist } from '@apextoaster/js-utils'; import { doesExist, mustExist } from '@apextoaster/js-utils';
import { ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material'; import { Brush, ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material';
import { Box, Button, Card, CardContent, CardMedia, Grid, Paper } from '@mui/material'; import { Box, Button, Card, CardContent, CardMedia, Grid, Paper } from '@mui/material';
import * as React from 'react'; import * as React from 'react';
import { useContext } from 'react'; import { useContext } from 'react';
@ -86,7 +86,7 @@ export function ImageCard(props: ImageCardProps) {
</GridItem> </GridItem>
<GridItem xs={2}> <GridItem xs={2}>
<Button onClick={copySourceToInpaint}> <Button onClick={copySourceToInpaint}>
<ContentCopyTwoTone /> <Brush />
</Button> </Button>
</GridItem> </GridItem>
<GridItem xs={2}> <GridItem xs={2}>

View File

@ -24,7 +24,13 @@ export function ImageInput(props: ImageInputProps) {
} }
if (doesExist(props.image)) { if (doesExist(props.image)) {
return <img src={URL.createObjectURL(props.image)} />; return <img
src={URL.createObjectURL(props.image)}
style={{
maxWidth: 512,
maxHeight: 512,
}}
/>;
} else { } else {
return <div>Please select an image.</div>; return <div>Please select an image.</div>;
} }

View File

@ -1,6 +1,6 @@
import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils'; import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
import { FormatColorFill, Gradient } from '@mui/icons-material'; import { FormatColorFill, Gradient } from '@mui/icons-material';
import { Button, Stack } from '@mui/material'; import { Button, Stack, Typography } from '@mui/material';
import { throttle } from 'lodash'; import { throttle } from 'lodash';
import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react'; import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react';
import { useStore } from 'zustand'; import { useStore } from 'zustand';
@ -230,6 +230,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
} }
}} }}
/> />
<Typography variant='body1'>
Black pixels in the mask will stay the same, white pixels will be replaced with pixels from the noise source.
</Typography>
<Stack direction='row' spacing={4}> <Stack direction='row' spacing={4}>
<NumericField <NumericField
label='Brush Color' label='Brush Color'

View File

@ -48,6 +48,19 @@ export function UpscaleControl(props: UpscaleControlProps) {
}); });
}} }}
/> />
<NumericField
label='Outscale'
disabled={params.enabled === false}
min={config.outscale.min}
max={config.outscale.max}
step={config.outscale.step}
value={params.outscale}
onChange={(outscale) => {
setUpscale({
outscale,
});
}}
/>
<NumericField <NumericField
label='Denoise' label='Denoise'
decimal decimal

View File

@ -241,6 +241,7 @@ export function createStateSlices(base: ConfigParams) {
enabled: false, enabled: false,
faces: false, faces: false,
scale: 1, scale: 1,
outscale: 1,
}, },
setUpscale(upscale) { setUpscale(upscale) {
set((prev) => ({ set((prev) => ({

View File

@ -34,6 +34,7 @@
"Onnx", "Onnx",
"onnxruntime", "onnxruntime",
"outpaint", "outpaint",
"outscale",
"pndm", "pndm",
"pretrained", "pretrained",
"protobuf", "protobuf",