1
0
Fork 0

feat(api): add ONNX implementation of Real ESRGAN net

This commit is contained in:
Sean Sube 2023-01-16 10:55:40 -06:00
parent 48963fa591
commit 9519fc16e9
1 changed files with 35 additions and 1 deletions

View File

@ -1,11 +1,14 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from gfpgan import GFPGANer
from onnxruntime import InferenceSession
from os import path
from PIL import Image
from realesrgan import RealESRGANer
from typing import Any
import numpy as np
import torch
denoise_strength = 0.5
fp16 = False
@ -19,10 +22,41 @@ gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPG
resrgan_name = 'RealESRGAN_x4plus'
resrgan_url = [
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
class ONNXNet():
'''
Provides the RRDBNet interface but using ONNX.
'''
def __init__(self) -> None:
self.session = InferenceSession(
resrgan_path, providers=['DmlExecutionProvider'])
def __call__(self, image: Any) -> Any:
input_name = self.session.get_inputs()[0].name
output_name = self.session.get_outputs()[0].name
output = self.session.run([output_name], {
input_name: image.cpu().numpy()
})[0]
return output
def eval(self) -> None:
pass
def half(self):
return self
def load_state_dict(self) -> None:
pass
def to(self, device):
return self
def make_resrgan(model_path):
model_path = path.join(model_path, resrgan_name + '.onnx')
model_path = path.join(model_path, resrgan_name + '.pth')
if not path.isfile(model_path):
for url in resrgan_url:
model_path = load_file_from_url(