1
0
Fork 0

switch RRDB nets based on upscaling

This commit is contained in:
Sean Sube 2023-12-30 13:28:16 -06:00
parent 0ddc16288f
commit 6834b716ea
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
4 changed files with 14 additions and 13 deletions

View File

@ -4,7 +4,7 @@ from os import path
import torch import torch
from torch.onnx import export from torch.onnx import export
from ...models.rrdb import RRDBNet from ...models.rrdb import RRDBNetRescale
from ..utils import ConversionContext, ModelDict from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__) logger = getLogger(__name__)
@ -27,7 +27,7 @@ def convert_correction_gfpgan(
logger.info("ONNX model already exists, skipping") logger.info("ONNX model already exists, skipping")
return return
model = RRDBNet( model = RRDBNetRescale(
num_in_ch=3, num_in_ch=3,
num_out_ch=3, num_out_ch=3,
num_feat=64, num_feat=64,

View File

@ -4,7 +4,7 @@ from os import path
import torch import torch
from torch.onnx import export from torch.onnx import export
from ...models.rrdb import RRDBNet from ...models.rrdb import RRDBNetRescale
from ..utils import ConversionContext, ModelDict from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__) logger = getLogger(__name__)
@ -28,7 +28,7 @@ def convert_upscaling_bsrgan(
return return
# values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69 # values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69
model = RRDBNet( model = RRDBNetRescale(
num_in_ch=3, num_in_ch=3,
num_out_ch=3, num_out_ch=3,
num_feat=64, num_feat=64,

View File

@ -87,6 +87,10 @@ def convert_upscale_resrgan(
else: else:
state_dict = torch_model state_dict = torch_model
if any(["RDB" in key for key in state_dict.keys()]):
# keys need fixed up to match. capitalized RDB is the best indicator.
state_dict = fix_resrgan_keys(state_dict)
if TAG_X4_V3 in name: if TAG_X4_V3 in name:
# the x4-v3 model needs a different network # the x4-v3 model needs a different network
model = SRVGGNetCompact( model = SRVGGNetCompact(
@ -97,10 +101,11 @@ def convert_upscale_resrgan(
upscale=scale, upscale=scale,
act_type="prelu", act_type="prelu",
) )
elif any(["RDB" in key for key in state_dict.keys()]): elif (
# keys need fixed up to match. capitalized RDB is the best indicator. "conv_up1.weight" in state_dict.keys()
state_dict = fix_resrgan_keys(state_dict) and "conv_up2.weight" in state_dict.keys()
model = RRDBNetFixed( ):
model = RRDBNetRescale(
num_in_ch=3, num_in_ch=3,
num_out_ch=3, num_out_ch=3,
num_feat=64, num_feat=64,
@ -109,7 +114,7 @@ def convert_upscale_resrgan(
scale=scale, scale=scale,
) )
else: else:
model = RRDBNetRescale( model = RRDBNetFixed(
num_in_ch=3, num_in_ch=3,
num_out_ch=3, num_out_ch=3,
num_feat=64, num_feat=64,

View File

@ -188,8 +188,6 @@ class RRDBNetFixed(nn.Module):
# upsampling # upsampling
if self.scale > 1: if self.scale > 1:
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
if self.scale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
@ -206,8 +204,6 @@ class RRDBNetFixed(nn.Module):
feat = self.lrelu( feat = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")) self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest"))
) )
if self.scale == 4:
feat = self.lrelu( feat = self.lrelu(
self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")) self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest"))
) )