better support for ESRGAN 1x models
This commit is contained in:
parent
698058018a
commit
477747cced
|
@ -22,6 +22,12 @@ SPECIAL_KEYS = {
|
|||
"model.3.weight": "conv_up1.weight",
|
||||
"model.6.bias": "conv_up2.bias",
|
||||
"model.6.weight": "conv_up2.weight",
|
||||
# 1x model keys
|
||||
"model.2.bias": "conv_hr.bias",
|
||||
"model.2.weight": "conv_hr.weight",
|
||||
"model.4.bias": "conv_last.bias",
|
||||
"model.4.weight": "conv_last.weight",
|
||||
# 2x and 4x model keys
|
||||
"model.8.bias": "conv_hr.bias",
|
||||
"model.8.weight": "conv_hr.weight",
|
||||
"model.10.bias": "conv_last.bias",
|
||||
|
|
|
@ -87,34 +87,46 @@ class RRDBNet(nn.Module):
|
|||
scale=4,
|
||||
):
|
||||
super(RRDBNet, self).__init__()
|
||||
self.scale = scale
|
||||
if scale == 2:
|
||||
num_in_ch = num_in_ch * 4
|
||||
elif scale == 1:
|
||||
num_in_ch = num_in_ch * 16
|
||||
|
||||
RRDB_block_f = functools.partial(RRDB, nf=num_feat, gc=num_grow_ch)
|
||||
self.sf = scale
|
||||
print([num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale])
|
||||
|
||||
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
|
||||
self.RRDB_trunk = make_layer(RRDB_block_f, num_block)
|
||||
self.trunk_conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
self.body = make_layer(RRDB_block_f, num_block)
|
||||
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
|
||||
# upsampling
|
||||
self.upconv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
if self.sf == 4:
|
||||
self.upconv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
self.HRconv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
if self.scale > 1:
|
||||
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
|
||||
if self.scale == 4:
|
||||
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
|
||||
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
||||
trunk = self.conv_body(self.body(fea))
|
||||
fea = fea + trunk
|
||||
|
||||
fea = self.lrelu(
|
||||
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
if self.sf == 4:
|
||||
if self.scale > 1:
|
||||
fea = self.lrelu(
|
||||
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
self.conv_up1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
||||
if self.scale == 4:
|
||||
fea = self.lrelu(
|
||||
self.conv_up2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||
)
|
||||
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
||||
|
||||
return out
|
||||
|
|
Loading…
Reference in New Issue