feat(api): add basic upscaling
This commit is contained in:
parent
64fac4d7aa
commit
77cb84c60e
|
@ -52,6 +52,7 @@ Based on guides by:
|
||||||
- [Note about setup paths](#note-about-setup-paths)
|
- [Note about setup paths](#note-about-setup-paths)
|
||||||
- [Create a virtual environment](#create-a-virtual-environment)
|
- [Create a virtual environment](#create-a-virtual-environment)
|
||||||
- [Install pip packages](#install-pip-packages)
|
- [Install pip packages](#install-pip-packages)
|
||||||
|
- [For upscaling and face correction](#for-upscaling-and-face-correction)
|
||||||
- [For AMD on Windows: Install ONNX DirectML](#for-amd-on-windows-install-onnx-directml)
|
- [For AMD on Windows: Install ONNX DirectML](#for-amd-on-windows-install-onnx-directml)
|
||||||
- [For CPU on Linux: Install PyTorch CPU](#for-cpu-on-linux-install-pytorch-cpu)
|
- [For CPU on Linux: Install PyTorch CPU](#for-cpu-on-linux-install-pytorch-cpu)
|
||||||
- [For CPU on Windows: Install PyTorch CPU](#for-cpu-on-windows-install-pytorch-cpu)
|
- [For CPU on Windows: Install PyTorch CPU](#for-cpu-on-windows-install-pytorch-cpu)
|
||||||
|
@ -190,6 +191,12 @@ sure you are not using `numpy>=1.24`.
|
||||||
[This SO question](https://stackoverflow.com/questions/74844262/how-to-solve-error-numpy-has-no-attribute-float-in-python)
|
[This SO question](https://stackoverflow.com/questions/74844262/how-to-solve-error-numpy-has-no-attribute-float-in-python)
|
||||||
has more details.
|
has more details.
|
||||||
|
|
||||||
|
#### For upscaling and face correction
|
||||||
|
|
||||||
|
```shell
|
||||||
|
> pip install basicsr facexlib gfpgan realesrgan
|
||||||
|
```
|
||||||
|
|
||||||
#### For AMD on Windows: Install ONNX DirectML
|
#### For AMD on Windows: Install ONNX DirectML
|
||||||
|
|
||||||
If you are running on Windows, install the DirectML ONNX runtime as well:
|
If you are running on Windows, install the DirectML ONNX runtime as well:
|
||||||
|
|
|
@ -44,6 +44,11 @@ from .image import (
|
||||||
noise_source_uniform,
|
noise_source_uniform,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .upscale import (
|
||||||
|
upscale_gfpgan,
|
||||||
|
upscale_resrgan,
|
||||||
|
)
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
|
@ -268,6 +273,8 @@ def run_txt2img_pipeline(model, provider, scheduler, prompt, negative_prompt, cf
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
num_inference_steps=steps,
|
num_inference_steps=steps,
|
||||||
).images[0]
|
).images[0]
|
||||||
|
|
||||||
|
image = upscale_resrgan(image)
|
||||||
image.save(output)
|
image.save(output)
|
||||||
|
|
||||||
print('saved txt2img output: %s' % (output))
|
print('saved txt2img output: %s' % (output))
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
from gfpgan import GFPGANer
|
||||||
|
from os import path
|
||||||
|
from PIL import Image
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
|
|
||||||
|
denoise_strength = 0.5
|
||||||
|
gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
||||||
|
resrgan_url = [
|
||||||
|
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
||||||
|
fp32 = True
|
||||||
|
model_name = 'RealESRGAN_x4plus'
|
||||||
|
netscale = 4
|
||||||
|
outscale = 4
|
||||||
|
pre_pad = 0
|
||||||
|
tile = 0
|
||||||
|
tile_pad = 10
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_resrgan(source_image: Image) -> Image:
|
||||||
|
model_path = path.join('weights', model_name + '.pth')
|
||||||
|
if not path.isfile(model_path):
|
||||||
|
ROOT_DIR = os.path.dirname(path.abspath(__file__))
|
||||||
|
for url in resrgan_url:
|
||||||
|
model_path = load_file_from_url(
|
||||||
|
url=url, model_dir=path.join(ROOT_DIR, 'weights'), progress=True, file_name=None)
|
||||||
|
|
||||||
|
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||||
|
num_block=23, num_grow_ch=32, scale=4)
|
||||||
|
|
||||||
|
dni_weight = None
|
||||||
|
if model_name == 'realesr-general-x4v3' and denoise_strength != 1:
|
||||||
|
wdn_model_path = model_path.replace(
|
||||||
|
'realesr-general-x4v3', 'realesr-general-wdn-x4v3')
|
||||||
|
model_path = [model_path, wdn_model_path]
|
||||||
|
dni_weight = [denoise_strength, 1 - denoise_strength]
|
||||||
|
|
||||||
|
upsampler = RealESRGANer(
|
||||||
|
scale=netscale,
|
||||||
|
model_path=model_path,
|
||||||
|
dni_weight=dni_weight,
|
||||||
|
model=model,
|
||||||
|
tile=tile,
|
||||||
|
tile_pad=tile_pad,
|
||||||
|
pre_pad=pre_pad,
|
||||||
|
half=fp32)
|
||||||
|
|
||||||
|
output, _ = upsampler.enhance(source_image, outscale=outscale)
|
||||||
|
|
||||||
|
return upscale_gfpgan(output, upsampler)
|
||||||
|
|
||||||
|
|
||||||
|
def upscale_gfpgan(source_image: Image, upsampler) -> Image:
|
||||||
|
face_enhancer = GFPGANer(
|
||||||
|
model_path=gfpgan_url,
|
||||||
|
upscale=outscale,
|
||||||
|
arch='clean',
|
||||||
|
channel_multiplier=2,
|
||||||
|
bg_upsampler=upsampler)
|
||||||
|
|
||||||
|
_, _, output = face_enhancer.enhance(source_image, has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
|
|
||||||
|
return output
|
Loading…
Reference in New Issue