From 477747cced21a456e0e00c57c7da5b447132e3ea Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 30 Dec 2023 11:50:28 -0600 Subject: [PATCH] better support for ESRGAN 1x models --- api/onnx_web/convert/upscaling/resrgan.py | 6 ++++ api/onnx_web/models/rrdb.py | 40 +++++++++++++++-------- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index 33a5d8d3..4b6e12fc 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -22,6 +22,12 @@ SPECIAL_KEYS = { "model.3.weight": "conv_up1.weight", "model.6.bias": "conv_up2.bias", "model.6.weight": "conv_up2.weight", + # 1x model keys + "model.2.bias": "conv_hr.bias", + "model.2.weight": "conv_hr.weight", + "model.4.bias": "conv_last.bias", + "model.4.weight": "conv_last.weight", + # 2x and 4x model keys "model.8.bias": "conv_hr.bias", "model.8.weight": "conv_hr.weight", "model.10.bias": "conv_last.bias", diff --git a/api/onnx_web/models/rrdb.py b/api/onnx_web/models/rrdb.py index dd5a14e4..54a43e2f 100644 --- a/api/onnx_web/models/rrdb.py +++ b/api/onnx_web/models/rrdb.py @@ -87,34 +87,46 @@ class RRDBNet(nn.Module): scale=4, ): super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + RRDB_block_f = functools.partial(RRDB, nf=num_feat, gc=num_grow_ch) - self.sf = scale print([num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale]) self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True) - self.RRDB_trunk = make_layer(RRDB_block_f, num_block) - self.trunk_conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.body = make_layer(RRDB_block_f, num_block) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + # upsampling - self.upconv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) - if self.sf == 4: - self.upconv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) - self.HRconv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + if self.scale > 1: + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + + if self.scale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): fea = self.conv_first(x) - trunk = self.trunk_conv(self.RRDB_trunk(fea)) + trunk = self.conv_body(self.body(fea)) fea = fea + trunk - fea = self.lrelu( - self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest")) - ) - if self.sf == 4: + if self.scale > 1: fea = self.lrelu( - self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest")) + self.conv_up1(F.interpolate(fea, scale_factor=2, mode="nearest")) ) - out = self.conv_last(self.lrelu(self.HRconv(fea))) + + if self.scale == 4: + fea = self.lrelu( + self.conv_up2(F.interpolate(fea, scale_factor=2, mode="nearest")) + ) + + out = self.conv_last(self.lrelu(self.conv_hr(fea))) return out