feat: add outscaling option
This commit is contained in:
parent
091c4e6109
commit
8d3ebede5a
|
@ -59,6 +59,7 @@ Based on guides by:
|
|||
- [For Nvidia everywhere: Install PyTorch GPU and ONNX GPU](#for-nvidia-everywhere-install-pytorch-gpu-and-onnx-gpu)
|
||||
- [Download and convert models](#download-and-convert-models)
|
||||
- [Test the models](#test-the-models)
|
||||
- [Upscaling and face correction](#upscaling-and-face-correction)
|
||||
- [Usage](#usage)
|
||||
- [Running the containers](#running-the-containers)
|
||||
- [Configuring and running the server](#configuring-and-running-the-server)
|
||||
|
@ -310,6 +311,13 @@ If the script works, there will be an image of an astronaut in `outputs/test.png
|
|||
|
||||
If you get any errors, check [the known errors section](#known-errors-and-solutions).
|
||||
|
||||
### Upscaling and face correction
|
||||
|
||||
Models:
|
||||
|
||||
- https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth
|
||||
- https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
|
||||
|
||||
## Usage
|
||||
|
||||
### Running the containers
|
||||
|
|
|
@ -106,6 +106,12 @@ mask_filters = {
|
|||
'gaussian-screen': mask_filter_gaussian_screen,
|
||||
}
|
||||
|
||||
# TODO: load from model_path
|
||||
upscale_models = [
|
||||
'RealESRGAN_x4plus',
|
||||
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', # TODO: convert GFPGAN
|
||||
]
|
||||
|
||||
|
||||
def serve_bundle_file(filename='index.html'):
|
||||
return send_from_directory(path.join('..', bundle_path), filename)
|
||||
|
@ -192,9 +198,17 @@ def border_from_request() -> Border:
|
|||
def upscale_from_request() -> UpscaleParams:
|
||||
denoise = get_and_clamp_float(request.args, 'denoise', 0.5, 1.0, 0.0)
|
||||
scale = get_and_clamp_int(request.args, 'scale', 1, 4, 1)
|
||||
outscale = get_and_clamp_int(request.args, 'outscale', 1, 4, 1)
|
||||
faces = request.args.get('faces', 'false') == 'true'
|
||||
platform = 'onnx'
|
||||
return UpscaleParams(scale=scale, faces=faces, platform=platform, denoise=denoise)
|
||||
return UpscaleParams(
|
||||
upscale_models[0],
|
||||
scale=scale,
|
||||
outscale=outscale,
|
||||
faces=faces,
|
||||
face_model=upscale_models[1],
|
||||
platform='onnx',
|
||||
denoise=denoise,
|
||||
)
|
||||
|
||||
def check_paths():
|
||||
if not path.exists(model_path):
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from gfpgan import GFPGANer
|
||||
from onnxruntime import InferenceSession
|
||||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Any
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -15,16 +14,9 @@ from .utils import (
|
|||
)
|
||||
|
||||
# TODO: these should all be params or config
|
||||
fp16 = False
|
||||
outscale = 4
|
||||
pre_pad = 0
|
||||
tile_pad = 10
|
||||
|
||||
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
||||
resrgan_name = 'RealESRGAN_x4plus'
|
||||
resrgan_url = [
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
||||
|
||||
|
||||
class ONNXImage():
|
||||
def __init__(self, source) -> None:
|
||||
|
@ -51,6 +43,9 @@ class ONNXImage():
|
|||
def numpy(self):
|
||||
return self.source
|
||||
|
||||
def size(self):
|
||||
return np.shape(self.source)
|
||||
|
||||
|
||||
class ONNXNet():
|
||||
'''
|
||||
|
@ -87,29 +82,44 @@ class ONNXNet():
|
|||
|
||||
|
||||
class UpscaleParams():
|
||||
def __init__(self, scale=4, faces=True, platform='onnx', denoise=0.5) -> None:
|
||||
self.denoise = denoise
|
||||
def __init__(
|
||||
self,
|
||||
upscale_model: str,
|
||||
scale: int = 4,
|
||||
outscale: int = 1,
|
||||
denoise: float = 0.5,
|
||||
faces=True,
|
||||
face_model: Union[str, None] = None,
|
||||
platform: str = 'onnx',
|
||||
half=False
|
||||
) -> None:
|
||||
self.upscale_model = upscale_model
|
||||
self.scale = scale
|
||||
self.outscale = outscale
|
||||
self.denoise = denoise
|
||||
self.faces = faces
|
||||
self.face_model = face_model
|
||||
self.platform = platform
|
||||
self.half = half
|
||||
|
||||
|
||||
def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
||||
model_path = path.join(ctx.model_path, resrgan_name + '.pth')
|
||||
model_path = path.join(ctx.model_path, '%s.%s' %
|
||||
(params.upscale_model, params.platform))
|
||||
if not path.isfile(model_path):
|
||||
for url in resrgan_url:
|
||||
model_path = load_file_from_url(
|
||||
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
|
||||
raise Exception('Real ESRGAN model not found at %s' % model_path)
|
||||
|
||||
# use ONNX acceleration, if available
|
||||
if params.platform == 'onnx':
|
||||
model = ONNXNet(ctx)
|
||||
else:
|
||||
elif params.platform == 'pth':
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=params.scale)
|
||||
else:
|
||||
raise Exception('unknown platform %s' % params.platform)
|
||||
|
||||
dni_weight = None
|
||||
if resrgan_name == 'realesr-general-x4v3' and params.denoise != 1:
|
||||
if params.upscale_model == 'realesr-general-x4v3' and params.denoise != 1:
|
||||
wdn_model_path = model_path.replace(
|
||||
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
||||
model_path = [model_path, wdn_model_path]
|
||||
|
@ -123,7 +133,7 @@ def make_resrgan(ctx: ServerContext, params: UpscaleParams, tile=0):
|
|||
tile=tile,
|
||||
tile_pad=tile_pad,
|
||||
pre_pad=pre_pad,
|
||||
half=fp16)
|
||||
half=params.half)
|
||||
|
||||
return upsampler
|
||||
|
||||
|
@ -134,8 +144,7 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
|
|||
image = np.array(source_image)
|
||||
upsampler = make_resrgan(ctx, params)
|
||||
|
||||
# TODO: what is outscale for here?
|
||||
output, _ = upsampler.enhance(image, outscale=outscale)
|
||||
output, _ = upsampler.enhance(image, outscale=params.outscale)
|
||||
|
||||
if params.faces:
|
||||
output = upscale_gfpgan(ctx, params, output)
|
||||
|
@ -144,14 +153,18 @@ def upscale_resrgan(ctx: ServerContext, params: UpscaleParams, source_image: Ima
|
|||
|
||||
|
||||
def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=None) -> Image:
|
||||
print('correcting faces with GFPGAN')
|
||||
print('correcting faces with GFPGAN model: %s' % params.face_model)
|
||||
|
||||
if params.face_model is None:
|
||||
print('no face model given, skipping')
|
||||
return image
|
||||
|
||||
if upsampler is None:
|
||||
upsampler = make_resrgan(ctx, params, tile=512)
|
||||
|
||||
face_enhancer = GFPGANer(
|
||||
model_path=gfpgan_url,
|
||||
upscale=outscale,
|
||||
model_path=params.face_model,
|
||||
upscale=params.outscale,
|
||||
arch='clean',
|
||||
channel_multiplier=2,
|
||||
bg_upsampler=upsampler)
|
||||
|
|
|
@ -61,6 +61,12 @@
|
|||
"max": 1,
|
||||
"step": 0.01
|
||||
},
|
||||
"outscale": {
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 4,
|
||||
"step": 1
|
||||
},
|
||||
"width": {
|
||||
"default": 512,
|
||||
"min": 64,
|
||||
|
|
|
@ -71,6 +71,7 @@ export interface UpscaleParams {
|
|||
denoise: number;
|
||||
faces: boolean;
|
||||
scale: number;
|
||||
outscale: number;
|
||||
}
|
||||
|
||||
export interface ApiResponse {
|
||||
|
@ -170,6 +171,7 @@ export function appendUpscaleToURL(url: URL, upscale: UpscaleParams) {
|
|||
url.searchParams.append('denoise', upscale.denoise.toFixed(FIXED_FLOAT));
|
||||
url.searchParams.append('faces', String(upscale.faces));
|
||||
url.searchParams.append('scale', upscale.scale.toFixed(FIXED_INTEGER));
|
||||
url.searchParams.append('outscale', upscale.outscale.toFixed(FIXED_INTEGER));
|
||||
}
|
||||
|
||||
export function makeClient(root: string, f = fetch): ApiClient {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import { doesExist, mustExist } from '@apextoaster/js-utils';
|
||||
import { ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material';
|
||||
import { Brush, ContentCopy, ContentCopyTwoTone, Delete, Download } from '@mui/icons-material';
|
||||
import { Box, Button, Card, CardContent, CardMedia, Grid, Paper } from '@mui/material';
|
||||
import * as React from 'react';
|
||||
import { useContext } from 'react';
|
||||
|
@ -86,7 +86,7 @@ export function ImageCard(props: ImageCardProps) {
|
|||
</GridItem>
|
||||
<GridItem xs={2}>
|
||||
<Button onClick={copySourceToInpaint}>
|
||||
<ContentCopyTwoTone />
|
||||
<Brush />
|
||||
</Button>
|
||||
</GridItem>
|
||||
<GridItem xs={2}>
|
||||
|
|
|
@ -24,7 +24,13 @@ export function ImageInput(props: ImageInputProps) {
|
|||
}
|
||||
|
||||
if (doesExist(props.image)) {
|
||||
return <img src={URL.createObjectURL(props.image)} />;
|
||||
return <img
|
||||
src={URL.createObjectURL(props.image)}
|
||||
style={{
|
||||
maxWidth: 512,
|
||||
maxHeight: 512,
|
||||
}}
|
||||
/>;
|
||||
} else {
|
||||
return <div>Please select an image.</div>;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import { doesExist, Maybe, mustExist } from '@apextoaster/js-utils';
|
||||
import { FormatColorFill, Gradient } from '@mui/icons-material';
|
||||
import { Button, Stack } from '@mui/material';
|
||||
import { Button, Stack, Typography } from '@mui/material';
|
||||
import { throttle } from 'lodash';
|
||||
import React, { RefObject, useContext, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useStore } from 'zustand';
|
||||
|
@ -230,6 +230,9 @@ export function MaskCanvas(props: MaskCanvasProps) {
|
|||
}
|
||||
}}
|
||||
/>
|
||||
<Typography variant='body1'>
|
||||
Black pixels in the mask will stay the same, white pixels will be replaced with pixels from the noise source.
|
||||
</Typography>
|
||||
<Stack direction='row' spacing={4}>
|
||||
<NumericField
|
||||
label='Brush Color'
|
||||
|
|
|
@ -48,6 +48,19 @@ export function UpscaleControl(props: UpscaleControlProps) {
|
|||
});
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label='Outscale'
|
||||
disabled={params.enabled === false}
|
||||
min={config.outscale.min}
|
||||
max={config.outscale.max}
|
||||
step={config.outscale.step}
|
||||
value={params.outscale}
|
||||
onChange={(outscale) => {
|
||||
setUpscale({
|
||||
outscale,
|
||||
});
|
||||
}}
|
||||
/>
|
||||
<NumericField
|
||||
label='Denoise'
|
||||
decimal
|
||||
|
|
|
@ -241,6 +241,7 @@ export function createStateSlices(base: ConfigParams) {
|
|||
enabled: false,
|
||||
faces: false,
|
||||
scale: 1,
|
||||
outscale: 1,
|
||||
},
|
||||
setUpscale(upscale) {
|
||||
set((prev) => ({
|
||||
|
|
|
@ -34,6 +34,7 @@
|
|||
"Onnx",
|
||||
"onnxruntime",
|
||||
"outpaint",
|
||||
"outscale",
|
||||
"pndm",
|
||||
"pretrained",
|
||||
"protobuf",
|
||||
|
|
Loading…
Reference in New Issue