From 54db63394f8ccc5925be494897524f81c48fd29b Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 30 Dec 2023 13:50:14 -0600 Subject: [PATCH] arch adjustments --- api/onnx_web/convert/upscaling/resrgan.py | 1 + api/onnx_web/models/rrdb.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index d8edc821..e4a159b0 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -105,6 +105,7 @@ def convert_upscale_resrgan( "conv_up1.weight" in state_dict.keys() and "conv_up2.weight" in state_dict.keys() ): + # both variants are the same for scale=4 model = RRDBNetRescale( num_in_ch=3, num_out_ch=3, diff --git a/api/onnx_web/models/rrdb.py b/api/onnx_web/models/rrdb.py index 601767fd..19249ef9 100644 --- a/api/onnx_web/models/rrdb.py +++ b/api/onnx_web/models/rrdb.py @@ -185,6 +185,8 @@ class RRDBNetFixed(nn.Module): # upsampling if self.scale > 1: self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + + if self.scale > 2: self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) @@ -201,6 +203,8 @@ class RRDBNetFixed(nn.Module): feat = self.lrelu( self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")) ) + + if self.scale > 2: feat = self.lrelu( self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")) )