From 6834b716ea4edc294c10478f6c681af7c1794a18 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 30 Dec 2023 13:28:16 -0600 Subject: [PATCH] switch RRDB nets based on upscaling --- api/onnx_web/convert/correction/gfpgan.py | 4 ++-- api/onnx_web/convert/upscaling/bsrgan.py | 4 ++-- api/onnx_web/convert/upscaling/resrgan.py | 15 ++++++++++----- api/onnx_web/models/rrdb.py | 4 ---- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/api/onnx_web/convert/correction/gfpgan.py b/api/onnx_web/convert/correction/gfpgan.py index 5a5e9508..2203a575 100644 --- a/api/onnx_web/convert/correction/gfpgan.py +++ b/api/onnx_web/convert/correction/gfpgan.py @@ -4,7 +4,7 @@ from os import path import torch from torch.onnx import export -from ...models.rrdb import RRDBNet +from ...models.rrdb import RRDBNetRescale from ..utils import ConversionContext, ModelDict logger = getLogger(__name__) @@ -27,7 +27,7 @@ def convert_correction_gfpgan( logger.info("ONNX model already exists, skipping") return - model = RRDBNet( + model = RRDBNetRescale( num_in_ch=3, num_out_ch=3, num_feat=64, diff --git a/api/onnx_web/convert/upscaling/bsrgan.py b/api/onnx_web/convert/upscaling/bsrgan.py index 915e4b9d..114d1b55 100644 --- a/api/onnx_web/convert/upscaling/bsrgan.py +++ b/api/onnx_web/convert/upscaling/bsrgan.py @@ -4,7 +4,7 @@ from os import path import torch from torch.onnx import export -from ...models.rrdb import RRDBNet +from ...models.rrdb import RRDBNetRescale from ..utils import ConversionContext, ModelDict logger = getLogger(__name__) @@ -28,7 +28,7 @@ def convert_upscaling_bsrgan( return # values based on https://github.com/cszn/BSRGAN/blob/main/main_test_bsrgan.py#L69 - model = RRDBNet( + model = RRDBNetRescale( num_in_ch=3, num_out_ch=3, num_feat=64, diff --git a/api/onnx_web/convert/upscaling/resrgan.py b/api/onnx_web/convert/upscaling/resrgan.py index be695898..d8edc821 100644 --- a/api/onnx_web/convert/upscaling/resrgan.py +++ b/api/onnx_web/convert/upscaling/resrgan.py @@ -87,6 +87,10 @@ def convert_upscale_resrgan( else: state_dict = torch_model + if any(["RDB" in key for key in state_dict.keys()]): + # keys need fixed up to match. capitalized RDB is the best indicator. + state_dict = fix_resrgan_keys(state_dict) + if TAG_X4_V3 in name: # the x4-v3 model needs a different network model = SRVGGNetCompact( @@ -97,10 +101,11 @@ def convert_upscale_resrgan( upscale=scale, act_type="prelu", ) - elif any(["RDB" in key for key in state_dict.keys()]): - # keys need fixed up to match. capitalized RDB is the best indicator. - state_dict = fix_resrgan_keys(state_dict) - model = RRDBNetFixed( + elif ( + "conv_up1.weight" in state_dict.keys() + and "conv_up2.weight" in state_dict.keys() + ): + model = RRDBNetRescale( num_in_ch=3, num_out_ch=3, num_feat=64, @@ -109,7 +114,7 @@ def convert_upscale_resrgan( scale=scale, ) else: - model = RRDBNetRescale( + model = RRDBNetFixed( num_in_ch=3, num_out_ch=3, num_feat=64, diff --git a/api/onnx_web/models/rrdb.py b/api/onnx_web/models/rrdb.py index 214c063d..e2c77b42 100644 --- a/api/onnx_web/models/rrdb.py +++ b/api/onnx_web/models/rrdb.py @@ -188,8 +188,6 @@ class RRDBNetFixed(nn.Module): # upsampling if self.scale > 1: self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) - - if self.scale == 4: self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) @@ -206,8 +204,6 @@ class RRDBNetFixed(nn.Module): feat = self.lrelu( self.conv_up1(F.interpolate(feat, scale_factor=2, mode="nearest")) ) - - if self.scale == 4: feat = self.lrelu( self.conv_up2(F.interpolate(feat, scale_factor=2, mode="nearest")) )