diff --git a/README.md b/README.md index 6274d440..f09979b0 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Based on guides by: - [Note about setup paths](#note-about-setup-paths) - [Create a virtual environment](#create-a-virtual-environment) - [Install pip packages](#install-pip-packages) + - [For upscaling and face correction](#for-upscaling-and-face-correction) - [For AMD on Windows: Install ONNX DirectML](#for-amd-on-windows-install-onnx-directml) - [For CPU on Linux: Install PyTorch CPU](#for-cpu-on-linux-install-pytorch-cpu) - [For CPU on Windows: Install PyTorch CPU](#for-cpu-on-windows-install-pytorch-cpu) @@ -190,6 +191,12 @@ sure you are not using `numpy>=1.24`. [This SO question](https://stackoverflow.com/questions/74844262/how-to-solve-error-numpy-has-no-attribute-float-in-python) has more details. +#### For upscaling and face correction + +```shell +> pip install basicsr facexlib gfpgan realesrgan +``` + #### For AMD on Windows: Install ONNX DirectML If you are running on Windows, install the DirectML ONNX runtime as well: diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index de3c3b88..2adaa866 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -44,6 +44,11 @@ from .image import ( noise_source_uniform, ) +from .upscale import ( + upscale_gfpgan, + upscale_resrgan, +) + import json import numpy as np import time @@ -268,6 +273,8 @@ def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf negative_prompt=negative_prompt, num_inference_steps=steps, ).images[0] + + image = upscale_resrgan(image) image.save(output) print('saved txt2img output: %s' % (output)) diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py new file mode 100644 index 00000000..166b5fad --- /dev/null +++ b/api/onnx_web/upscale.py @@ -0,0 +1,64 @@ +from basicsr.archs.rrdbnet_arch import RRDBNet +from basicsr.utils.download_util import load_file_from_url +from gfpgan import GFPGANer +from os import path +from PIL import 
Image
+from realesrgan import RealESRGANer
+import numpy as np
+
+# Module-level defaults for the Real-ESRGAN / GFPGAN pipeline.
+denoise_strength = 0.5
+gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
+resrgan_url = [
+    'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
+fp32 = True
+model_name = 'RealESRGAN_x4plus'
+netscale = 4
+outscale = 4
+pre_pad = 0
+tile = 0
+tile_pad = 10
+
+
+def _image_to_bgr(image: Image.Image) -> np.ndarray:
+    """Convert a PIL image into a contiguous BGR array.
+
+    RealESRGANer and GFPGANer operate on OpenCV-style arrays
+    (height x width x BGR channels), not on PIL images.
+    """
+    rgb = np.asarray(image.convert('RGB'))
+    # Reversing the channel axis yields a negative-stride view; copy it into
+    # a contiguous buffer so downstream code can treat it as a plain array.
+    return np.ascontiguousarray(rgb[:, :, ::-1])
+
+
+def _bgr_to_image(pixels: np.ndarray) -> Image.Image:
+    """Convert an OpenCV-style BGR array back into an RGB PIL image."""
+    return Image.fromarray(np.ascontiguousarray(pixels[:, :, ::-1]))
+
+
+def upscale_resrgan(source_image: Image.Image) -> Image.Image:
+    """Upscale an image with Real-ESRGAN, then apply GFPGAN face correction.
+
+    Args:
+        source_image: the PIL image to enlarge.
+
+    Returns:
+        The enlarged, face-corrected result as a PIL image, so callers can
+        keep using PIL methods such as ``save`` on it.
+    """
+    model_path = path.join('weights', model_name + '.pth')
+    if not path.isfile(model_path):
+        # No weights relative to the working directory: fetch them into a
+        # `weights` directory next to this module. Only `os.path` is imported
+        # in this file (as `path`), so resolve the directory through it.
+        root_dir = path.dirname(path.abspath(__file__))
+        for url in resrgan_url:
+            model_path = load_file_from_url(
+                url=url, model_dir=path.join(root_dir, 'weights'), progress=True, file_name=None)
+
+    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
+                    num_block=23, num_grow_ch=32, scale=netscale)
+
+    # Only the general-x4v3 model supports blending with its denoising
+    # variant; every other model is loaded from a single weights file.
+    dni_weight = None
+    if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
+        wdn_model_path = model_path.replace(
+            'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
+        model_path = [model_path, wdn_model_path]
+        dni_weight = [denoise_strength, 1 - denoise_strength]
+
+    upsampler = RealESRGANer(
+        scale=netscale,
+        model_path=model_path,
+        dni_weight=dni_weight,
+        model=model,
+        tile=tile,
+        tile_pad=tile_pad,
+        pre_pad=pre_pad,
+        # `half` selects fp16 inference, which is the opposite of `fp32`.
+        half=not fp32)
+
+    output, _ = upsampler.enhance(
+        _image_to_bgr(source_image), outscale=outscale)
+
+    # NOTE(review): upscale_gfpgan hands `upsampler` to GFPGANer as its
+    # background upsampler with upscale=outscale, so the already-enlarged
+    # image appears to be enlarged a second time there - confirm that the
+    # combined outscale * outscale factor is intended.
+    return upscale_gfpgan(_bgr_to_image(output), upsampler)
+
+
+def upscale_gfpgan(source_image: Image.Image, upsampler) -> Image.Image:
+    """Restore faces in an image with GFPGAN.
+
+    Args:
+        source_image: the PIL image whose faces should be corrected.
+        upsampler: background upsampler forwarded to GFPGANer as
+            ``bg_upsampler`` (upscale_resrgan passes its RealESRGANer).
+
+    Returns:
+        The restored image as a PIL image.
+    """
+    face_enhancer = GFPGANer(
+        model_path=gfpgan_url,
+        upscale=outscale,
+        arch='clean',
+        channel_multiplier=2,
+        bg_upsampler=upsampler)
+
+    # Only the pasted-back full image is needed; the cropped and restored
+    # face lists are discarded.
+    _, _, output = face_enhancer.enhance(
+        _image_to_bgr(source_image),
+        has_aligned=False,
+        only_center_face=False,
+        paste_back=True)
+
+    return _bgr_to_image(output)