diff --git a/api/onnx_web/models/rrdb.py b/api/onnx_web/models/rrdb.py index e2c77b42..601767fd 100644 --- a/api/onnx_web/models/rrdb.py +++ b/api/onnx_web/models/rrdb.py @@ -142,15 +142,12 @@ class RRDBNetRescale(nn.Module): trunk = self.conv_body(self.body(feat)) feat = feat + trunk - if self.scale > 1: - feat = self.lrelu( - self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")) - ) - - if self.scale == 4: - feat = self.lrelu( - self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")) - ) + feat = self.lrelu( + self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")) + ) + feat = self.lrelu( + self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")) + ) out = self.conv_last(self.lrelu(self.conv_hr(feat)))