diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index d8edc821..e4a159b0 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -105,6 +105,7 @@ def convert_upscale_resrgan( "conv_up1.weight" in state_dict.keys() and "conv_up2.weight" in state_dict.keys() ): + # both variants are the same for scale=4 model = RRDBNetRescale( num_in_ch=3, num_out_ch=3, diff --git a/api/onnx_web/models/rrdb.py b/api/onnx_web/models/rrdb.py index 601767fd..19249ef9 100644 --- a/api/onnx_web/models/rrdb.py +++ b/api/onnx_web/models/rrdb.py @@ -185,6 +185,8 @@ class RRDBNetFixed(nn.Module): # upsampling if self.scale > 1: self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + + if self.scale > 2: self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) @@ -201,6 +203,8 @@ class RRDBNetFixed(nn.Module): feat = self.lrelu( self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")) ) + + if self.scale > 2: feat = self.lrelu( self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")) )