diff --git a/api/onnx_web/upscale.py b/api/onnx_web/upscale.py index 22dd5fc5..4b0b6430 100644 --- a/api/onnx_web/upscale.py +++ b/api/onnx_web/upscale.py @@ -24,6 +24,10 @@ resrgan_url = [ 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'] resrgan_path = path.join('..', 'models', 'RealESRGAN_x4plus.onnx') +class ONNXImage(): + def __init__(self, data) -> None: + self.data = data + class ONNXNet(): ''' @@ -40,7 +44,7 @@ class ONNXNet(): output = self.session.run([output_name], { input_name: image.cpu().numpy() })[0] - return output + return ONNXImage(output) def eval(self) -> None: pass @@ -62,8 +66,9 @@ def make_resrgan(model_path): model_path = load_file_from_url( url=url, model_dir=path.join(model_path, resrgan_name), progress=True, file_name=None) - model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, - num_block=23, num_grow_ch=32, scale=4) + # model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, + # num_block=23, num_grow_ch=32, scale=4) + model = ONNXNet() dni_weight = None if resrgan_name == 'realesr-general-x4v3' and denoise_strength != 1: