From 2c9d96d2eec8f4f042d76d1d7db281a44c0d5e15 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Fri, 17 Feb 2023 23:25:42 -0600 Subject: [PATCH 1/3] feat(api): use ONNX for Real ESRGAN v3 model --- api/onnx_web/chain/upscale_resrgan.py | 37 ++++++++++++++------------- api/onnx_web/onnx/__init__.py | 2 +- api/onnx_web/onnx/onnx_net.py | 8 +++--- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 484314ba..0c60984a 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -37,16 +37,6 @@ def load_resrgan( if not path.isfile(model_path): raise Exception("Real ESRGAN model not found at %s" % model_path) - if x4_v3_tag in model_file: - # the x4-v3 model needs a different network - model = SRVGGNetCompact( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_conv=32, - upscale=4, - act_type="prelu", - ) elif params.format == "onnx": # use ONNX acceleration, if available model = OnnxNet( @@ -56,14 +46,25 @@ def load_resrgan( sess_options=device.sess_options(), ) elif params.format == "pth": - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=params.scale, - ) + if x4_v3_tag in model_file: + # the x4-v3 model needs a different network + model = SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=4, + act_type="prelu", + ) + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=params.scale, + ) else: raise Exception("unknown platform %s" % params.format) diff --git a/api/onnx_web/onnx/__init__.py b/api/onnx_web/onnx/__init__.py index 9d30760b..8e58b60a 100644 --- a/api/onnx_web/onnx/__init__.py +++ b/api/onnx_web/onnx/__init__.py @@ -1 +1 @@ -from .onnx_net import OnnxImage, OnnxNet +from .onnx_net import OnnxTensor, OnnxNet diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 51697380..f6001f62 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions from ..utils import ServerContext -class OnnxImage: +class OnnxTensor: def __init__(self, source) -> None: self.source = source self.data = self @@ -57,8 +57,10 @@ class OnnxNet: def __call__(self, image: Any) -> Any: input_name = self.session.get_inputs()[0].name output_name = self.session.get_outputs()[0].name - output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0] - return OnnxImage(output) + output = self.session.run([output_name], { + input_name: image.cpu().numpy() + })[0] + return OnnxTensor(output) def eval(self) -> None: pass From 338fc237c7c7c8043881472fc593930f8cb83166 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Feb 2023 08:41:15 -0600 Subject: [PATCH 2/3] fix(api): convert Real ESRGAN v3 using same arch as runtime --- api/onnx_web/chain/upscale_resrgan.py | 10 ++++---- api/onnx_web/convert/upscale_resrgan.py | 31 ++++++++++++++++++------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 0c60984a..b0e618f1 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -14,7 +14,7 @@ logger = getLogger(__name__) last_pipeline_instance = None last_pipeline_params = (None, None) -x4_v3_tag = "real-esrgan-x4-v3" +TAG_X4_V3 = "real-esrgan-x4-v3" def load_resrgan( @@ -37,7 +37,7 @@ def load_resrgan( if not path.isfile(model_path): raise Exception("Real ESRGAN model not found at %s" % model_path) - elif params.format == "onnx": + if params.format == "onnx": # use ONNX acceleration, if available model = OnnxNet( server, @@ -46,7 +46,7 @@ def load_resrgan( sess_options=device.sess_options(), ) elif params.format == "pth": - if x4_v3_tag in model_file: + if TAG_X4_V3 in model_file: # the x4-v3 model needs a different network model = SRVGGNetCompact( num_in_ch=3, @@ -69,8 +69,8 @@ def load_resrgan( raise Exception("unknown platform %s" % params.format) dni_weight = None - if params.upscale_model == x4_v3_tag and params.denoise != 1: - wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag)) + if params.upscale_model == TAG_X4_V3 and params.denoise != 1: + wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3) model_path = [model_path, wdn_model_path] dni_weight = [params.denoise, 1 - params.denoise] diff --git a/api/onnx_web/convert/upscale_resrgan.py b/api/onnx_web/convert/upscale_resrgan.py index 7e6f172d..1e3cb721 100644 --- a/api/onnx_web/convert/upscale_resrgan.py +++ b/api/onnx_web/convert/upscale_resrgan.py @@ -3,12 +3,15 @@ from os import path import torch from basicsr.archs.rrdbnet_arch import RRDBNet +from realesrgan.archs.srvgg_arch import SRVGGNetCompact from torch.onnx import export from .utils import ConversionContext, ModelDict logger = getLogger(__name__) +TAG_X4_V3 = "real-esrgan-x4-v3" + @torch.no_grad() def convert_upscale_resrgan( @@ -28,14 +31,26 @@ def convert_upscale_resrgan( return logger.info("loading and training model") - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=scale, - ) + + if TAG_X4_V3 in name: + # the x4-v3 model needs a different network + model = SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=scale, + act_type="prelu", + ) + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=scale, + ) torch_model = torch.load(source, map_location=ctx.map_location) if "params_ema" in torch_model: From b3b10b474649d6688c4aba8277373292c5090341 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 18 Feb 2023 08:46:46 -0600 Subject: [PATCH 3/3] apply lint --- api/onnx_web/chain/upscale_resrgan.py | 2 +- api/onnx_web/onnx/onnx_net.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index b0e618f1..4ee8438f 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -46,7 +46,7 @@ def load_resrgan( sess_options=device.sess_options(), ) elif params.format == "pth": - if TAG_X4_V3 in model_file: + if TAG_X4_V3 in model_file: # the x4-v3 model needs a different network model = SRVGGNetCompact( num_in_ch=3, diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index f6001f62..bf0cf524 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -57,9 +57,7 @@ class OnnxNet: def __call__(self, image: Any) -> Any: input_name = self.session.get_inputs()[0].name output_name = self.session.get_outputs()[0].name - output = self.session.run([output_name], { - input_name: image.cpu().numpy() - })[0] + output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0] return OnnxTensor(output) def eval(self) -> None: