diff --git a/api/onnx_web/chain/upscale_resrgan.py b/api/onnx_web/chain/upscale_resrgan.py index 0c60984a..b0e618f1 100644 --- a/api/onnx_web/chain/upscale_resrgan.py +++ b/api/onnx_web/chain/upscale_resrgan.py @@ -14,7 +14,7 @@ logger = getLogger(__name__) last_pipeline_instance = None last_pipeline_params = (None, None) -x4_v3_tag = "real-esrgan-x4-v3" +TAG_X4_V3 = "real-esrgan-x4-v3" def load_resrgan( @@ -37,7 +37,7 @@ def load_resrgan( if not path.isfile(model_path): raise Exception("Real ESRGAN model not found at %s" % model_path) - elif params.format == "onnx": + if params.format == "onnx": # use ONNX acceleration, if available model = OnnxNet( server, @@ -46,7 +46,7 @@ def load_resrgan( sess_options=device.sess_options(), ) elif params.format == "pth": - if x4_v3_tag in model_file: + if TAG_X4_V3 in model_file: # the x4-v3 model needs a different network model = SRVGGNetCompact( num_in_ch=3, @@ -69,8 +69,8 @@ def load_resrgan( raise Exception("unknown platform %s" % params.format) dni_weight = None - if params.upscale_model == x4_v3_tag and params.denoise != 1: - wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag)) + if params.upscale_model == TAG_X4_V3 and params.denoise != 1: + wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3) model_path = [model_path, wdn_model_path] dni_weight = [params.denoise, 1 - params.denoise] diff --git a/api/onnx_web/convert/upscale_resrgan.py b/api/onnx_web/convert/upscale_resrgan.py index 7e6f172d..1e3cb721 100644 --- a/api/onnx_web/convert/upscale_resrgan.py +++ b/api/onnx_web/convert/upscale_resrgan.py @@ -3,12 +3,15 @@ from os import path import torch from basicsr.archs.rrdbnet_arch import RRDBNet +from realesrgan.archs.srvgg_arch import SRVGGNetCompact from torch.onnx import export from .utils import ConversionContext, ModelDict logger = getLogger(__name__) +TAG_X4_V3 = "real-esrgan-x4-v3" + @torch.no_grad() def convert_upscale_resrgan( @@ -28,14 +31,26 @@ def convert_upscale_resrgan( return logger.info("loading and training model") - model = RRDBNet( - num_in_ch=3, - num_out_ch=3, - num_feat=64, - num_block=23, - num_grow_ch=32, - scale=scale, - ) + + if TAG_X4_V3 in name: + # the x4-v3 model needs a different network + model = SRVGGNetCompact( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_conv=32, + upscale=scale, + act_type="prelu", + ) + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=scale, + ) torch_model = torch.load(source, map_location=ctx.map_location) if "params_ema" in torch_model: