feat(api): add ONNX implementation of Real ESRGAN net
This commit is contained in:
parent
48963fa591
commit
9519fc16e9
|
@ -1,11 +1,14 @@
|
|||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from gfpgan import GFPGANer
|
||||
from onnxruntime import InferenceSession
|
||||
from os import path
|
||||
from PIL import Image
|
||||
from realesrgan import RealESRGANer
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
denoise_strength = 0.5
|
||||
fp16 = False
|
||||
|
@ -19,10 +22,41 @@ gfpgan_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPG
|
|||
resrgan_name = 'RealESRGAN_x4plus'
|
||||
resrgan_url = [
|
||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
||||
resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
|
||||
|
||||
|
||||
class ONNXNet():
|
||||
'''
|
||||
Provides the RRDBNet interface but using ONNX.
|
||||
'''
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.session = InferenceSession(
|
||||
resrgan_path, providers=['DmlExecutionProvider'])
|
||||
|
||||
def __call__(self, image: Any) -> Any:
|
||||
input_name = self.session.get_inputs()[0].name
|
||||
output_name = self.session.get_outputs()[0].name
|
||||
output = self.session.run([output_name], {
|
||||
input_name: image.cpu().numpy()
|
||||
})[0]
|
||||
return output
|
||||
|
||||
def eval(self) -> None:
|
||||
pass
|
||||
|
||||
def half(self):
|
||||
return self
|
||||
|
||||
def load_state_dict(self) -> None:
|
||||
pass
|
||||
|
||||
def to(self, device):
|
||||
return self
|
||||
|
||||
|
||||
def make_resrgan(model_path):
|
||||
model_path = path.join(model_path, resrgan_name + '.onnx')
|
||||
model_path = path.join(model_path, resrgan_name + '.pth')
|
||||
if not path.isfile(model_path):
|
||||
for url in resrgan_url:
|
||||
model_path = load_file_from_url(
|
||||
|
|
Loading…
Reference in New Issue