diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 484314ba..4ee8438f 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -14,7 +14,7 @@ logger = getLogger(__name__) last_pipeline_instance = None last_pipeline_params = (None, None) -x4_v3_tag = "real-esrgan-x4-v3" +TAG_X4_V3 = "real-esrgan-x4-v3" def load_resrgan( @@ -37,17 +37,7 @@ def load_resrgan( if not path.isfile(model_path): raise Exception("Real ESRGAN model not found at %s" % model_path) - if x4_v3_tag in model_file: - # the x4-v3 model needs a different network - model = SRVGGNetCompact( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_conv=32, - upscale=4, - act_type="prelu", - ) - elif params.format == "onnx": + if params.format == "onnx": # use ONNX acceleration, if available model = OnnxNet( server, @@ -56,20 +46,31 @@ def load_resrgan( sess_options=device.sess_options(), ) elif params.format == "pth": - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=params.scale, - ) + if TAG_X4_V3 in model_file: + # the x4-v3 model needs a different network + model = SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=4, + act_type="prelu", + ) + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=params.scale, + ) else: raise Exception("unknown platform %s" % params.format) dni_weight = None - if params.upscale_model == x4_v3_tag and params.denoise != 1: - wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag)) + if params.upscale_model == TAG_X4_V3 and params.denoise != 1: + wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3) model_path = [model_path, wdn_model_path] dni_weight = [params.denoise, 1 - params.denoise] diff --git a/api/onnx_web/convert/upscale_resrgan.py b/api/onnx_web/convert/upscale_resrgan.py index 7e6f172d..1e3cb721 100644 --- a/api/onnx_web/convert/upscale_resrgan.py +++ b/api/onnx_web/convert/upscale_resrgan.py @@ -3,12 +3,15 @@ from os import path import torch from basicsr.archs.rrdbnet_arch import RRDBNet +from realesrgan.archs.srvgg_arch import SRVGGNetCompact from torch.onnx import export from .utils import ConversionContext, ModelDict logger = getLogger(__name__) +TAG_X4_V3 = "real-esrgan-x4-v3" + @torch.no_grad() def convert_upscale_resrgan( @@ -28,14 +31,26 @@ def convert_upscale_resrgan( return logger.info("loading and training model") - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=scale, - ) + + if TAG_X4_V3 in name: + # the x4-v3 model needs a different network + model = SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=scale, + act_type="prelu", + ) + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=scale, + ) torch_model = torch.load(source, map_location=ctx.map_location) if "params_ema" in torch_model: diff --git a/api/onnx_web/onnx/__init__.py b/api/onnx_web/onnx/__init__.py index 9d30760b..8e58b60a 100644 --- a/api/onnx_web/onnx/__init__.py +++ b/api/onnx_web/onnx/__init__.py @@ -1 +1 @@ -from .onnx_net import OnnxImage, OnnxNet +from .onnx_net import OnnxTensor, OnnxNet diff --git a/api/onnx_web/onnx/onnx_net.py b/api/onnx_web/onnx/onnx_net.py index 51697380..bf0cf524 100644 --- a/api/onnx_web/onnx/onnx_net.py +++ b/api/onnx_web/onnx/onnx_net.py @@ -8,7 +8,7 @@ from onnxruntime import InferenceSession, SessionOptions from ..utils import ServerContext -class OnnxImage: +class OnnxTensor: def __init__(self, source) -> None: self.source = source self.data = self @@ -58,7 +58,7 @@ class OnnxNet: input_name = self.session.get_inputs()[0].name output_name = self.session.get_outputs()[0].name output = self.session.run([output_name], {input_name: image.cpu().numpy()})[0] - return OnnxImage(output) + return OnnxTensor(output) def eval(self) -> None: pass