feat(api): support both ESRGAN variants
parent 6cee411e43
commit 0ddc16288f
@@ -5,7 +5,7 @@ from re import compile
 import torch
 from torch.onnx import export
 
-from ...models.rrdb import RRDBNet
+from ...models.rrdb import RRDBNetFixed, RRDBNetRescale
 from ...models.srvgg import SRVGGNetCompact
 from ..utils import ConversionContext, ModelDict
 
@@ -79,6 +79,14 @@ def convert_upscale_resrgan(
         logger.info("ONNX model already exists, skipping")
         return
 
+    torch_model = torch.load(source, map_location=conversion.map_location)
+    if "params_ema" in torch_model:
+        state_dict = torch_model["params_ema"]
+    elif "params" in torch_model:
+        state_dict = torch_model["params"]
+    else:
+        state_dict = torch_model
+
     if TAG_X4_V3 in name:
         # the x4-v3 model needs a different network
         model = SRVGGNetCompact(
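Note: the new loading block mirrors how Real-ESRGAN packages its checkpoints: official releases usually nest the weights under "params_ema" (the EMA-smoothed copy) or "params", while many third-party files are the bare state dict. A minimal standalone sketch of the same fallback chain (the file name is hypothetical):

    import torch

    # hypothetical path; the converter receives it as `source`
    checkpoint = torch.load("RealESRGAN_x4plus.pth", map_location="cpu")
    # nested weights are unwrapped; a bare state dict falls through unchanged
    state_dict = checkpoint.get("params_ema") or checkpoint.get("params") or checkpoint
    print(list(state_dict.keys())[:5])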
@@ -89,8 +97,19 @@ def convert_upscale_resrgan(
             upscale=scale,
             act_type="prelu",
         )
+    elif any(["RDB" in key for key in state_dict.keys()]):
+        # keys need to be fixed up to match; capitalized RDB is the best indicator
+        state_dict = fix_resrgan_keys(state_dict)
+        model = RRDBNetFixed(
+            num_in_ch=3,
+            num_out_ch=3,
+            num_feat=64,
+            num_block=23,
+            num_grow_ch=32,
+            scale=scale,
+        )
     else:
-        model = RRDBNet(
+        model = RRDBNetRescale(
             num_in_ch=3,
             num_out_ch=3,
             num_feat=64,
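Note: fix_resrgan_keys is not shown in this diff. Given that this commit renames the RRDB attributes from RDB1..RDB3 to rdb1..rdb3, a plausible sketch (illustrative only, not the project's actual helper) is a renamer that lowercases the capitalized RDB segments so older checkpoints line up with the new module names:

    def fix_resrgan_keys_sketch(state_dict):
        # illustrative: "body.0.RDB1.conv1.weight" -> "body.0.rdb1.conv1.weight"
        return {key.replace("RDB", "rdb"): value for key, value in state_dict.items()}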
@@ -99,15 +118,7 @@ def convert_upscale_resrgan(
             scale=scale,
         )
 
-    torch_model = torch.load(source, map_location=conversion.map_location)
-    if "params_ema" in torch_model:
-        model.load_state_dict(torch_model["params_ema"])
-    elif "params" in torch_model:
-        model.load_state_dict(torch_model["params"], strict=False)
-    else:
-        # keys need fixed up to match
-        model.load_state_dict(fix_resrgan_keys(torch_model), strict=False)
+    model.load_state_dict(state_dict, strict=True)
 
     model.to(conversion.training_device).train(False)
     model.eval()
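Note: switching from strict=False to strict=True is a behavioral improvement: now that the keys are normalized up front, any remaining mismatch raises instead of being silently dropped. A tiny illustration (the Linear module is just a stand-in):

    import torch
    from torch import nn

    model = nn.Linear(2, 2)  # stand-in for any nn.Module
    try:
        model.load_state_dict({"bogus": torch.zeros(1)}, strict=True)
    except RuntimeError as err:
        print(err)  # lists the missing and unexpected keys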
@@ -77,18 +77,24 @@ class RRDB(nn.Module):
 
     def __init__(self, nf, gc=32):
         super(RRDB, self).__init__()
-        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
-        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb1 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb2 = ResidualDenseBlock_5C(nf, gc)
+        self.rdb3 = ResidualDenseBlock_5C(nf, gc)
 
     def forward(self, x):
-        out = self.RDB1(x)
-        out = self.RDB2(out)
-        out = self.RDB3(out)
+        out = self.rdb1(x)
+        out = self.rdb2(out)
+        out = self.rdb3(out)
         return out * 0.2 + x
 
 
-class RRDBNet(nn.Module):
+class RRDBNetRescale(nn.Module):
+    """
+    RRDBNet with variable input channels based on scale.
+    This is the format expected by the official models.
+    In this architecture, the modules stay the same but input channels change.
+    """
+
     def __init__(
         self,
         num_in_ch=3,
@@ -98,7 +104,7 @@ class RRDBNet(nn.Module):
         num_grow_ch=32,
         scale=4,
     ):
-        super(RRDBNet, self).__init__()
+        super(RRDBNetRescale, self).__init__()
         self.scale = scale
 
         if scale == 2:
@@ -107,7 +113,7 @@ class RRDBNet(nn.Module):
             num_in_ch = num_in_ch * 16
 
         logger.trace(
-            "RRDBNet params: %s",
+            "RRDBNetRescale params: %s",
            [num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
         )
 
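Note: the channel multiplication above comes from pixel-unshuffling the input before conv_first: unshuffling by a factor f trades spatial size for channels, multiplying the channel count by f^2. Assuming the elided branch follows the upstream Real-ESRGAN pattern (factor 2 for scale 2, factor 4 for scale 1), a quick shape check:

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 3, 64, 64)
    print(F.pixel_unshuffle(x, 2).shape)  # torch.Size([1, 12, 32, 32]): 3 * 2**2 channels
    print(F.pixel_unshuffle(x, 4).shape)  # torch.Size([1, 48, 16, 16]): 3 * 4**2 channels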
@@ -116,10 +122,7 @@ class RRDBNet(nn.Module):
         self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
 
         # upsampling
-        if self.scale > 1:
-            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
-
-            if self.scale == 4:
-                self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
 
         self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
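Note: this removal is the crux of the split. RRDBNetRescale now registers conv_up1 and conv_up2 unconditionally, so its state dict carries those keys at every scale, while RRDBNetFixed (added below) keeps the guards and only creates the layers it needs. A quick way to see the difference, assuming the classes as defined in this diff:

    rescale = RRDBNetRescale(scale=1)
    fixed = RRDBNetFixed(scale=1)

    print("conv_up1.weight" in rescale.state_dict())  # True: upsampling layers always exist
    print("conv_up1.weight" in fixed.state_dict())    # False: guarded by scale > 1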
@@ -152,3 +155,63 @@ class RRDBNet(nn.Module):
         out = self.conv_last(self.lrelu(self.conv_hr(feat)))
 
         return out
+
+
+class RRDBNetFixed(nn.Module):
+    """
+    RRDBNet with fixed input channels regardless of scale.
+    This is the format expected by many third-party models.
+    In this architecture, the modules come and go based on scale, but the input channels stay the same.
+    """
+
+    def __init__(
+        self,
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=4,
+    ):
+        super(RRDBNetFixed, self).__init__()
+        self.scale = scale
+
+        logger.trace(
+            "RRDBNetFixed params: %s",
+            [num_in_ch, num_out_ch, num_feat, num_block, num_grow_ch, scale],
+        )
+
+        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
+        self.body = make_layer(RRDB, num_block, nf=num_feat, gc=num_grow_ch)
+        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        # upsampling
+        if self.scale > 1:
+            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+            if self.scale == 4:
+                self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+
+        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        feat = self.conv_first(x)
+        trunk = self.conv_body(self.body(feat))
+        feat = feat + trunk
+
+        if self.scale > 1:
+            feat = self.lrelu(
+                self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
+            )
+
+            if self.scale == 4:
+                feat = self.lrelu(
+                    self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
+                )
+
+        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+
+        return out
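Note: a smoke test along these lines (shapes assumed from the architectures above, not part of this commit) would exercise both variants end to end:

    import torch

    x = torch.rand(1, 3, 64, 64)

    for cls in (RRDBNetRescale, RRDBNetFixed):
        model = cls(num_in_ch=3, num_out_ch=3, scale=4).eval()
        with torch.no_grad():
            y = model(x)
        print(cls.__name__, tuple(y.shape))  # expect (1, 3, 256, 256) at scale=4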