switch RRDB nets based on upscaling
This commit is contained in:
parent
0ddc16288f
commit
6834b716ea
|
@ -4,7 +4,7 @@ from os import path
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
from ...models.rrdb import RRDBNet
|
from ...models.rrdb import RRDBNetRescale
|
||||||
from ..utils import ConversionContext, ModelDict
|
from ..utils import ConversionContext, ModelDict
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -27,7 +27,7 @@ def convert_correction_gfpgan(
|
||||||
logger.info("ONNX model already exists, skipping")
|
logger.info("ONNX model already exists, skipping")
|
||||||
return
|
return
|
||||||
|
|
||||||
model = RRDBNet(
|
model = RRDBNetRescale(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
num_out_ch=3,
|
num_out_ch=3,
|
||||||
num_feat=64,
|
num_feat=64,
|
||||||
|
|
|
@ -4,7 +4,7 @@ from os import path
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
from ...models.rrdb import RRDBNet
|
from ...models.rrdb import RRDBNetRescale
|
||||||
from ..utils import ConversionContext, ModelDict
|
from ..utils import ConversionContext, ModelDict
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -28,7 +28,7 @@ def convert_upscaling_bsrgan(
|
||||||
return
|
return
|
||||||
|
|
||||||
# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
|
# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
|
||||||
model = RRDBNet(
|
model = RRDBNetRescale(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
num_out_ch=3,
|
num_out_ch=3,
|
||||||
num_feat=64,
|
num_feat=64,
|
||||||
|
|
|
@ -87,6 +87,10 @@ def convert_upscale_resrgan(
|
||||||
else:
|
else:
|
||||||
state_dict = torch_model
|
state_dict = torch_model
|
||||||
|
|
||||||
|
if any(["RDB" in key for key in state_dict.keys()]):
|
||||||
|
# keys need fixed up to match. capitalized RDB is the best indicator.
|
||||||
|
state_dict = fix_resrgan_keys(state_dict)
|
||||||
|
|
||||||
if TAG_X4_V3 in name:
|
if TAG_X4_V3 in name:
|
||||||
# the x4-v3 model needs a different network
|
# the x4-v3 model needs a different network
|
||||||
model = SRVGGNetCompact(
|
model = SRVGGNetCompact(
|
||||||
|
@ -97,10 +101,11 @@ def convert_upscale_resrgan(
|
||||||
upscale=scale,
|
upscale=scale,
|
||||||
act_type="prelu",
|
act_type="prelu",
|
||||||
)
|
)
|
||||||
elif any(["RDB" in key for key in state_dict.keys()]):
|
elif (
|
||||||
# keys need fixed up to match. capitalized RDB is the best indicator.
|
"conv_up1.weight" in state_dict.keys()
|
||||||
state_dict = fix_resrgan_keys(state_dict)
|
and "conv_up2.weight" in state_dict.keys()
|
||||||
model = RRDBNetFixed(
|
):
|
||||||
|
model = RRDBNetRescale(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
num_out_ch=3,
|
num_out_ch=3,
|
||||||
num_feat=64,
|
num_feat=64,
|
||||||
|
@ -109,7 +114,7 @@ def convert_upscale_resrgan(
|
||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = RRDBNetRescale(
|
model = RRDBNetFixed(
|
||||||
num_in_ch=3,
|
num_in_ch=3,
|
||||||
num_out_ch=3,
|
num_out_ch=3,
|
||||||
num_feat=64,
|
num_feat=64,
|
||||||
|
|
|
@ -188,8 +188,6 @@ class RRDBNetFixed(nn.Module):
|
||||||
# upsampling
|
# upsampling
|
||||||
if self.scale > 1:
|
if self.scale > 1:
|
||||||
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
if self.scale == 4:
|
|
||||||
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||||
|
|
||||||
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
|
||||||
|
@ -206,8 +204,6 @@ class RRDBNetFixed(nn.Module):
|
||||||
feat = self.lrelu(
|
feat = self.lrelu(
|
||||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
|
self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.scale == 4:
|
|
||||||
feat = self.lrelu(
|
feat = self.lrelu(
|
||||||
self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
|
self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue