From 2c9d96d2eec8f4f042d76d1d7db281a44c0d5e15 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 17 Feb 2023 23:25:42 -0600 Subject: [PATCH] feat(api): use ONNX for Real ESRGAN v3 model --- api/onnx_web/chain/upscale_resrgan.py | 37 ++++++++++++++------------- api/onnx_web/onnx/__init__.py | 2 +- api/onnx_web/onnx/onnx_net.py | 8 +++--- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 484314ba..0c60984a 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -37,16 +37,6 @@ def load_resrgan( if not path.isfile(model_path): raise Exception("Real ESRGAN model not found at %s" % model_path) - if x4_v3_tag in model_file: - # the x4-v3 model needs a different network - model = SRVGGNetCompact( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_conv=32, - upscale=4, - act_type="prelu", - ) elif params.format == "onnx": # use ONNX acceleration, if available model = OnnxNet( @@ -56,14 +46,25 @@ def load_resrgan( sess_options=device.sess_options(), ) elif params.format == "pth": - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=params.scale, - ) + if x4_v3_tag in model_file: + # the x4-v3 model needs a different network + model = SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=4, + act_type="prelu", + ) + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=params.scale, + ) else: raise Exception("unknown platform %s" % params.format) diff --git a/api/onnx_web/onnx/__init__.py b/api/onnx_web/onnx/__init__.py index 9d30760b..8e58b60a 100644 --- a/api/onnx_web/onnx/__init__.py +++ b/api/onnx_web/onnx/__init__.py @@ -1 +1 @@ -from .onnx_net import OnnxImage, OnnxNet +from .onnx_net import OnnxTensor, OnnxNet diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 51697380..f6001f62 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions from ..utils import ServerContext -class OnnxImage: +class OnnxTensor: def __init__(self, source) -> None: self.source = source self.data = self @@ -57,8 +57,10 @@ class OnnxNet: def __call__(self, image: Any) -> Any: input_name = self.session.get_inputs()[0].name output_name = self.session.get_outputs()[0].name - output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0] - return OnnxImage(output) + output = self.session.run([output_name], { + input_name: image.cpu().numpy() + })[0] + return OnnxTensor(output) def eval(self) -> None: pass