better support for ESRGAN 1x models
This commit is contained in:
parent
698058018a
commit
477747cced
|
@ -22,6 +22,12 @@ SPECIAL_KEYS = {
|
||||||
"model.3.weight": "conv_up1.weight",
|
"model.3.weight": "conv_up1.weight",
|
||||||
"model.6.bias": "conv_up2.bias",
|
"model.6.bias": "conv_up2.bias",
|
||||||
"model.6.weight": "conv_up2.weight",
|
"model.6.weight": "conv_up2.weight",
|
||||||
|
# 1x model keys
|
||||||
|
"model.2.bias": "conv_hr.bias",
|
||||||
|
"model.2.weight": "conv_hr.weight",
|
||||||
|
"model.4.bias": "conv_last.bias",
|
||||||
|
"model.4.weight": "conv_last.weight",
|
||||||
|
# 2x and 4x model keys
|
||||||
"model.8.bias": "conv_hr.bias",
|
"model.8.bias": "conv_hr.bias",
|
||||||
"model.8.weight": "conv_hr.weight",
|
"model.8.weight": "conv_hr.weight",
|
||||||
"model.10.bias": "conv_last.bias",
|
"model.10.bias": "conv_last.bias",
|
||||||
|
|
|
@ -87,34 +87,46 @@ class RRDBNet(nn.Module):
|
||||||
scale=4,
|
scale=4,
|
||||||
):
|
):
|
||||||
super(RRDBNet, self).__init__()
|
super(RRDBNet, self).__init__()
|
||||||
|
self.scale = scale
|
||||||
|
if scale == 2:
|
||||||
|
num_in_ch = num_in_ch * 4
|
||||||
|
elif scale == 1:
|
||||||
|
num_in_ch = num_in_ch * 16
|
||||||
|
|
||||||
RRDB_block_f = functools.partial(RRDB, nf=num_feat, gc=num_grow_ch)
|
RRDB_block_f = functools.partial(RRDB, nf=num_feat, gc=num_grow_ch)
|
||||||
self.sf = scale
|
|
||||||
print([num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale])
|
print([num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale])
|
||||||
|
|
||||||
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
|
self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
|
||||||
self.RRDB_trunk = make_layer(RRDB_block_f, num_block)
|
self.body = make_layer(RRDB_block_f, num_block)
|
||||||
self.trunk_conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
# upsampling
|
# upsampling
|
||||||
self.upconv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
if self.scale > 1:
|
||||||
if self.sf == 4:
|
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||||
self.upconv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
|
||||||
self.HRconv = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
if self.scale == 4:
|
||||||
|
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
|
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
|
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
fea = self.conv_first(x)
|
fea = self.conv_first(x)
|
||||||
trunk = self.trunk_conv(self.RRDB_trunk(fea))
|
trunk = self.conv_body(self.body(fea))
|
||||||
fea = fea + trunk
|
fea = fea + trunk
|
||||||
|
|
||||||
fea = self.lrelu(
|
if self.scale > 1:
|
||||||
self.upconv1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
|
||||||
)
|
|
||||||
if self.sf == 4:
|
|
||||||
fea = self.lrelu(
|
fea = self.lrelu(
|
||||||
self.upconv2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
self.conv_up1(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||||
)
|
)
|
||||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
|
||||||
|
if self.scale == 4:
|
||||||
|
fea = self.lrelu(
|
||||||
|
self.conv_up2(F.interpolate(fea, scale_factor=2, mode="nearest"))
|
||||||
|
)
|
||||||
|
|
||||||
|
out = self.conv_last(self.lrelu(self.conv_hr(fea)))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
Loading…
Reference in New Issue