diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py
index 4b6e12fc..be695898 100644
--- a/api/onnx_web/convert/upscaling/resrgan.py
+++ b/api/onnx_web/convert/upscaling/resrgan.py
@@ -5,7 +5,7 @@ from re import compile
 import torch
 from torch.onnx import export
 
-from ...models.rrdb import RRDBNet
+from ...models.rrdb import RRDBNetFixed, RRDBNetRescale
 from ...models.srvgg import SRVGGNetCompact
 from ..utils import ConversionContext, ModelDict
 
@@ -79,6 +79,14 @@ def convert_upscale_resrgan(
         logger.info("ONNX model already exists, skipping")
         return
 
+    torch_model = torch.load(source, map_location=conversion.map_location)
+    if "params_ema" in torch_model:
+        state_dict = torch_model["params_ema"]
+    elif "params" in torch_model:
+        state_dict = torch_model["params"]
+    else:
+        state_dict = torch_model
+
     if TAG_X4_V3 in name:
         # the x4-v3 model needs a different network
         model = SRVGGNetCompact(
@@ -89,8 +97,19 @@ def convert_upscale_resrgan(
             upscale=scale,
             act_type="prelu",
         )
+    elif any(["RDB" in key for key in state_dict.keys()]):
+        # keys need fixed up to match. capitalized RDB is the best indicator.
+        state_dict = fix_resrgan_keys(state_dict)
+        model = RRDBNetFixed(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=scale,
+        )
     else:
-        model = RRDBNet(
+        model = RRDBNetRescale(
             num_in_ch=3,
             num_out_ch=3,
             num_feat=64,
@@ -99,15 +118,7 @@ def convert_upscale_resrgan(
             scale=scale,
         )
 
-    torch_model = torch.load(source, map_location=conversion.map_location)
-    if "params_ema" in torch_model:
-        model.load_state_dict(torch_model["params_ema"])
-    elif "params" in torch_model:
-        model.load_state_dict(torch_model["params"], strict=False)
-    else:
-        # keys need fixed up to match
-        model.load_state_dict(fix_resrgan_keys(torch_model), strict=False)
-
+    model.load_state_dict(state_dict, strict=True)
     model.to(conversion.training_device).train(False)
     model.eval()
 
diff --git a/api/onnx_web/models/rrdb.py b/api/onnx_web/models/rrdb.py
index 496a9e50..214c063d 100644
--- a/api/onnx_web/models/rrdb.py
+++ b/api/onnx_web/models/rrdb.py
@@ -77,18 +77,24 @@ class RRDB(nn.Module):
 
     def __init__(self, nf, gc=32):
         super(RRDB, self).__init__()
-        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb1 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb2 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb3 = ResidualDenseBlock_5C(nf, gc)
 
     def forward(self, x):
-        out = self.RDB1(x)
-        out = self.RDB2(out)
-        out = self.RDB3(out)
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
         return out * 0.2 + x
 
 
-class RRDBNet(nn.Module):
+class RRDBNetRescale(nn.Module):
+    """
+    RRDBNet with variable input channels based on scale.
+    This is the format expected by the official models.
+    In this architecture, the modules stay the same but input channels change.
+    """
+
     def __init__(
         self,
         num_in_ch=3,
@@ -98,7 +104,7 @@ class RRDBNet(nn.Module):
         num_grow_ch=32,
         scale=4,
     ):
-        super(RRDBNet, self).__init__()
+        super(RRDBNetRescale, self).__init__()
         self.scale = scale
 
         if scale == 2:
@@ -107,7 +113,7 @@ class RRDBNet(nn.Module):
             num_in_ch = num_in_ch * 16
 
         logger.trace(
-            "RRDBNet params: %s",
+            "RRDBNetRescale params: %s",
             [num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
         )
 
@@ -116,11 +122,8 @@ class RRDBNet(nn.Module):
         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
 
         # upsampling
-        if self.scale > 1:
-            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-
-        if self.scale == 4:
-            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
 
         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
@@ -152,3 +155,63 @@ class RRDBNet(nn.Module):
         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
 
         return out
+
+
+class RRDBNetFixed(nn.Module):
+    """
+    RRDBNet with fixed input channels regardless of scale.
+    This is the format expected by many third-party models.
+    In this architecture, the modules come and go based on scale, but the input channels stay the same.
+    """
+
+    def __init__(
+        self,
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=4,
+    ):
+        super(RRDBNetFixed, self).__init__()
+        self.scale = scale
+
+        logger.trace(
+            "RRDBNetFixed params: %s",
+            [num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
+        )
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
+        self.body = make_layer(RRDB, num_block, nf=num_feat, gc=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        # upsampling
+        if self.scale > 1:
+            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        if self.scale == 4:
+            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        feat = self.conv_first(x)
+        trunk = self.conv_body(self.body(feat))
+        feat = feat + trunk
+
+        if self.scale > 1:
+            feat = self.lrelu(
+                self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
+            )
+
+        if self.scale == 4:
+            feat = self.lrelu(
+                self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
+            )
+
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+
+        return out