
better support for ESRGAN 1x models

Sean Sube 2023-12-30 11:50:28 -06:00
parent 698058018a
commit 477747cced
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 32 additions and 14 deletions


@@ -22,6 +22,12 @@ SPECIAL_KEYS = {
     "model.3.weight": "conv_up1.weight",
     "model.6.bias": "conv_up2.bias",
     "model.6.weight": "conv_up2.weight",
+    # 1x model keys
+    "model.2.bias": "conv_hr.bias",
+    "model.2.weight": "conv_hr.weight",
+    "model.4.bias": "conv_last.bias",
+    "model.4.weight": "conv_last.weight",
+    # 2x and 4x model keys
     "model.8.bias": "conv_hr.bias",
     "model.8.weight": "conv_hr.weight",
     "model.10.bias": "conv_last.bias",


@@ -87,34 +87,46 @@ class RRDBNet(nn.Module):
         scale=4,
     ):
         super(RRDBNet, self).__init__()
+        self.scale = scale
+        if scale == 2:
+            num_in_ch = num_in_ch * 4
+        elif scale == 1:
+            num_in_ch = num_in_ch * 16
+
         RRDB_block_f = functools.partial(RRDB, nf=num_feat, gc=num_grow_ch)
-        self.sf = scale
         print([num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale])

         self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
-        self.RRDB_trunk = make_layer(RRDB_block_f, num_block)
-        self.trunk_conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.body = make_layer(RRDB_block_f, num_block)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
         # upsampling
-        self.upconv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-        if self.sf == 4:
-            self.upconv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-        self.HRconv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        if self.scale > 1:
+            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+            if self.scale == 4:
+                self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)

         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

     def forward(self, x):
         fea = self.conv_first(x)
-        trunk = self.trunk_conv(self.RRDB_trunk(fea))
+        trunk = self.conv_body(self.body(fea))
         fea = fea + trunk

-        fea = self.lrelu(
-            self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
-        )
-        if self.sf == 4:
+        if self.scale > 1:
             fea = self.lrelu(
-                self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
+                self.conv_up1(F.interpolate(fea, scale_factor=2, mode="nearest"))
             )
-        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+            if self.scale == 4:
+                fea = self.lrelu(
+                    self.conv_up2(F.interpolate(fea, scale_factor=2, mode="nearest"))
+                )
+
+        out = self.conv_last(self.lrelu(self.conv_hr(fea)))

         return out
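With this change the layer set depends on the scale argument: 1x models skip conv_up1 and conv_up2 entirely and expand num_in_ch by 16 (4 for 2x models), which appears intended to match checkpoints whose first convolution expects a pixel-unshuffled input. A quick sanity-check sketch, assuming RRDBNet is importable from this module and using example hyperparameters rather than values taken from the repository:

# Illustrative sketch only: example hyperparameters, not repository defaults.
net_1x = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, scale=1)
net_4x = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, scale=4)

# 1x models no longer create the upsampling convolutions...
assert not hasattr(net_1x, "conv_up1") and not hasattr(net_1x, "conv_up2")
# ...while 4x models keep both, and every scale keeps conv_hr / conv_last.
assert hasattr(net_4x, "conv_up1") and hasattr(net_4x, "conv_up2")
assert hasattr(net_1x, "conv_hr") and hasattr(net_1x, "conv_last")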