wrap image output
This commit is contained in:
parent
fe657468bf
commit
30d474b487
|
@ -24,6 +24,10 @@ resrgan_url = [
|
||||||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
||||||
resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
|
resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
|
||||||
|
|
||||||
|
class ONNXImage():
|
||||||
|
def __init__(self, data) -> None:
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
|
||||||
class ONNXNet():
|
class ONNXNet():
|
||||||
'''
|
'''
|
||||||
|
@ -40,7 +44,7 @@ class ONNXNet():
|
||||||
output = self.session.run([output_name], {
|
output = self.session.run([output_name], {
|
||||||
input_name: image.cpu().numpy()
|
input_name: image.cpu().numpy()
|
||||||
})[0]
|
})[0]
|
||||||
return output
|
return ONNXImage(output)
|
||||||
|
|
||||||
def eval(self) -> None:
|
def eval(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -62,8 +66,9 @@ def make_resrgan(model_path):
|
||||||
model_path = load_file_from_url(
|
model_path = load_file_from_url(
|
||||||
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
|
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
|
||||||
|
|
||||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
# model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||||
num_block=23, num_grow_ch=32, scale=4)
|
# num_block=23, num_grow_ch=32, scale=4)
|
||||||
|
model = ONNXNet()
|
||||||
|
|
||||||
dni_weight = None
|
dni_weight = None
|
||||||
if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1:
|
if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1:
|
||||||
|
|
Loading…
Reference in New Issue