1
0
Fork 0

switch RRDB nets based on upscaling

This commit is contained in:
Sean Sube 2023-12-30 13:28:16 -06:00
parent 0ddc16288f
commit 6834b716ea
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 14 additions and 13 deletions

View File

@ -4,7 +4,7 @@ from os import path
import torch
from torch.onnx import export
from ...models.rrdb import RRDBNet
from ...models.rrdb import RRDBNetRescale
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@ -27,7 +27,7 @@ def convert_correction_gfpgan(
logger.info("ONNX model already exists, skipping")
return
model = RRDBNet(
model = RRDBNetRescale(
num_in_ch=3,
num_out_ch=3,
num_feat=64,

View File

@ -4,7 +4,7 @@ from os import path
import torch
from torch.onnx import export
from ...models.rrdb import RRDBNet
from ...models.rrdb import RRDBNetRescale
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@ -28,7 +28,7 @@ def convert_upscaling_bsrgan(
return
# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
model = RRDBNet(
model = RRDBNetRescale(
num_in_ch=3,
num_out_ch=3,
num_feat=64,

View File

@ -87,6 +87,10 @@ def convert_upscale_resrgan(
else:
state_dict = torch_model
if any(["RDB" in key for key in state_dict.keys()]):
# keys need fixed up to match. capitalized RDB is the best indicator.
state_dict = fix_resrgan_keys(state_dict)
if TAG_X4_V3 in name:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
@ -97,10 +101,11 @@ def convert_upscale_resrgan(
upscale=scale,
act_type="prelu",
)
elif any(["RDB" in key for key in state_dict.keys()]):
# keys need fixed up to match. capitalized RDB is the best indicator.
state_dict = fix_resrgan_keys(state_dict)
model = RRDBNetFixed(
elif (
"conv_up1.weight" in state_dict.keys()
and "conv_up2.weight" in state_dict.keys()
):
model = RRDBNetRescale(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
@ -109,7 +114,7 @@ def convert_upscale_resrgan(
scale=scale,
)
else:
model = RRDBNetRescale(
model = RRDBNetFixed(
num_in_ch=3,
num_out_ch=3,
num_feat=64,

View File

@ -188,8 +188,6 @@ class RRDBNetFixed(nn.Module):
# upsampling
if self.scale > 1:
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
if self.scale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
@ -206,8 +204,6 @@ class RRDBNetFixed(nn.Module):
feat = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
)
if self.scale == 4:
feat = self.lrelu(
self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
)