
feat(api): support both ESRGAN variants

Sean Sube 2023-12-30 13:11:50 -06:00
parent 6cee411e43
commit 0ddc16288f
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 99 additions and 25 deletions

View File

@@ -5,7 +5,7 @@ from re import compile
 import torch
 from torch.onnx import export
 
-from ...models.rrdb import RRDBNet
+from ...models.rrdb import RRDBNetFixed, RRDBNetRescale
 from ...models.srvgg import SRVGGNetCompact
 from ..utils import ConversionContext, ModelDict
@@ -79,6 +79,14 @@ def convert_upscale_resrgan(
         logger.info("ONNX model already exists, skipping")
         return
 
+    torch_model = torch.load(source, map_location=conversion.map_location)
+    if "params_ema" in torch_model:
+        state_dict = torch_model["params_ema"]
+    elif "params" in torch_model:
+        state_dict = torch_model["params"]
+    else:
+        state_dict = torch_model
+
     if TAG_X4_V3 in name:
         # the x4-v3 model needs a different network
         model = SRVGGNetCompact(
@@ -89,8 +97,19 @@ def convert_upscale_resrgan(
             upscale=scale,
             act_type="prelu",
         )
+    elif any(["RDB" in key for key in state_dict.keys()]):
+        # keys need fixed up to match. capitalized RDB is the best indicator.
+        state_dict = fix_resrgan_keys(state_dict)
+        model = RRDBNetFixed(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=scale,
+        )
     else:
-        model = RRDBNet(
+        model = RRDBNetRescale(
             num_in_ch=3,
             num_out_ch=3,
             num_feat=64,
@@ -99,15 +118,7 @@ def convert_upscale_resrgan(
             scale=scale,
         )
 
-    torch_model = torch.load(source, map_location=conversion.map_location)
-    if "params_ema" in torch_model:
-        model.load_state_dict(torch_model["params_ema"])
-    elif "params" in torch_model:
-        model.load_state_dict(torch_model["params"], strict=False)
-    else:
-        # keys need fixed up to match
-        model.load_state_dict(fix_resrgan_keys(torch_model), strict=False)
-
+    model.load_state_dict(state_dict, strict=True)
+
     model.to(conversion.training_device).train(False)
     model.eval()

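The helper fix_resrgan_keys called in the new elif branch is defined elsewhere in this module and is not part of this diff. As a rough, hypothetical sketch of the kind of remapping that branch relies on (not the repository's actual implementation), the capitalized RDB block names only need to be normalized to the lowercase rdb1/rdb2/rdb3 attribute names used by the RRDB module below:

# Hypothetical sketch only: the real fix_resrgan_keys may handle more cases.
# It illustrates lowercasing capitalized RDB keys so they match the renamed
# rdb1/rdb2/rdb3 attributes in models/rrdb.py.
def fix_resrgan_keys_sketch(state_dict):
    fixed = {}
    for key, value in state_dict.items():
        fixed[key.replace("RDB", "rdb")] = value
    return fixed

Once the keys are normalized up front, a single model.load_state_dict(state_dict, strict=True) call can replace the earlier strict=False loads, so mismatched or leftover keys fail loudly instead of being silently ignored.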
View File

@@ -77,18 +77,24 @@ class RRDB(nn.Module):
     def __init__(self, nf, gc=32):
         super(RRDB, self).__init__()
-        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb1 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb2 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb3 = ResidualDenseBlock_5C(nf, gc)
 
     def forward(self, x):
-        out = self.RDB1(x)
-        out = self.RDB2(out)
-        out = self.RDB3(out)
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
         return out * 0.2 + x
 
 
-class RRDBNet(nn.Module):
+class RRDBNetRescale(nn.Module):
+    """
+    RRDBNet with variable input channels based on scale.
+    This is the format expected by the official models.
+
+    In this architecture, the modules stay the same but input channels change.
+    """
+
     def __init__(
         self,
         num_in_ch=3,
@@ -98,7 +104,7 @@ class RRDBNet(nn.Module):
         num_grow_ch=32,
         scale=4,
     ):
-        super(RRDBNet, self).__init__()
+        super(RRDBNetRescale, self).__init__()
         self.scale = scale
 
         if scale == 2:
@@ -107,7 +113,7 @@ class RRDBNet(nn.Module):
             num_in_ch = num_in_ch * 16
 
         logger.trace(
-            "RRDBNet params: %s",
+            "RRDBNetRescale params: %s",
             [num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
         )
@@ -116,11 +122,8 @@ class RRDBNet(nn.Module):
         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
 
         # upsampling
-        if self.scale > 1:
-            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-
-            if self.scale == 4:
-                self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
         self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
@@ -152,3 +155,63 @@ class RRDBNet(nn.Module):
         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
 
         return out
+
+
+class RRDBNetFixed(nn.Module):
+    """
+    RRDBNet with fixed input channels regardless of scale.
+    This is the format expected by many third-party models.
+
+    In this architecture, the modules come and go based on scale, but the input channels stay the same.
+    """
+
+    def __init__(
+        self,
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=4,
+    ):
+        super(RRDBNetFixed, self).__init__()
+        self.scale = scale
+
+        logger.trace(
+            "RRDBNetFixed params: %s",
+            [num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
+        )
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
+        self.body = make_layer(RRDB, num_block, nf=num_feat, gc=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        # upsampling
+        if self.scale > 1:
+            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+            if self.scale == 4:
+                self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        feat = self.conv_first(x)
+        trunk = self.conv_body(self.body(feat))
+        feat = feat + trunk
+
+        if self.scale > 1:
+            feat = self.lrelu(
+                self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
+            )
+
+            if self.scale == 4:
+                feat = self.lrelu(
+                    self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
+                )
+
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+
+        return out
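
To summarize the difference between the two variants added here: RRDBNetRescale always builds both upsampling convolutions and instead adjusts num_in_ch for smaller scales, while RRDBNetFixed keeps three input channels and only creates conv_up1/conv_up2 when the scale requires them. A minimal usage sketch follows; it is not part of the commit, and the import path plus the 4x output of the rescale variant's forward (not shown in this diff) are assumptions.

# Illustrative only: assumes the onnx_web package layout implied by the
# relative imports in the conversion script above.
import torch

from onnx_web.models.rrdb import RRDBNetFixed, RRDBNetRescale

fixed = RRDBNetFixed(num_in_ch=3, num_out_ch=3, scale=4)
rescale = RRDBNetRescale(num_in_ch=3, num_out_ch=3, scale=4)

x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    print(fixed(x).shape)    # torch.Size([1, 3, 256, 256]), two nearest-neighbor x2 steps
    print(rescale(x).shape)  # expected to match at scale=4 (its forward is not shown in this diff)

In the conversion script, whichever variant is selected is loaded with the checkpoint's state dict, switched to eval mode, and then presumably passed to the export function imported from torch.onnx at the top of the file to produce the ONNX model.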