wrap image output
This commit is contained in:
parent
fe657468bf
commit
30d474b487
|
@ -24,6 +24,10 @@ resrgan_url = [
|
|||
'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth']
|
||||
resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx')
|
||||
|
||||
class ONNXImage():
|
||||
def __init__(self, data) -> None:
|
||||
self.data = data
|
||||
|
||||
|
||||
class ONNXNet():
|
||||
'''
|
||||
|
@ -40,7 +44,7 @@ class ONNXNet():
|
|||
output = self.session.run([output_name], {
|
||||
input_name: image.cpu().numpy()
|
||||
})[0]
|
||||
return output
|
||||
return ONNXImage(output)
|
||||
|
||||
def eval(self) -> None:
|
||||
pass
|
||||
|
@ -62,8 +66,9 @@ def make_resrgan(model_path):
|
|||
model_path = load_file_from_url(
|
||||
url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None)
|
||||
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=4)
|
||||
# model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
# num_block=23, num_grow_ch=32, scale=4)
|
||||
model = ONNXNet()
|
||||
|
||||
dni_weight = None
|
||||
if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1:
|
||||
|
|
Loading…
Reference in New Issue