fix(api): convert Real ESRGAN v3 using same arch as runtime
This commit is contained in:
parent
2c9d96d2ee
commit
338fc237c7
|
@ -14,7 +14,7 @@ logger = getLogger(__name__)
|
|||
last_pipeline_instance = None
|
||||
last_pipeline_params = (None, None)
|
||||
|
||||
x4_v3_tag = "real-esrgan-x4-v3"
|
||||
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||
|
||||
|
||||
def load_resrgan(
|
||||
|
@ -37,7 +37,7 @@ def load_resrgan(
|
|||
if not path.isfile(model_path):
|
||||
raise Exception("Real ESRGAN model not found at %s" % model_path)
|
||||
|
||||
elif params.format == "onnx":
|
||||
if params.format == "onnx":
|
||||
# use ONNX acceleration, if available
|
||||
model = OnnxNet(
|
||||
server,
|
||||
|
@ -46,7 +46,7 @@ def load_resrgan(
|
|||
sess_options=device.sess_options(),
|
||||
)
|
||||
elif params.format == "pth":
|
||||
if x4_v3_tag in model_file:
|
||||
if TAG_X4_V3 in model_file:
|
||||
# the x4-v3 model needs a different network
|
||||
model = SRVGGNetCompact(
|
||||
num_in_ch=3,
|
||||
|
@ -69,8 +69,8 @@ def load_resrgan(
|
|||
raise Exception("unknown platform %s" % params.format)
|
||||
|
||||
dni_weight = None
|
||||
if params.upscale_model == x4_v3_tag and params.denoise != 1:
|
||||
wdn_model_path = model_path.replace(x4_v3_tag, "%s-wdn" % (x4_v3_tag))
|
||||
if params.upscale_model == TAG_X4_V3 and params.denoise != 1:
|
||||
wdn_model_path = model_path.replace(TAG_X4_V3, "%s-wdn" % TAG_X4_V3)
|
||||
model_path = [model_path, wdn_model_path]
|
||||
dni_weight = [params.denoise, 1 - params.denoise]
|
||||
|
||||
|
|
|
@ -3,12 +3,15 @@ from os import path
|
|||
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from torch.onnx import export
|
||||
|
||||
from .utils import ConversionContext, ModelDict
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
TAG_X4_V3 = "real-esrgan-x4-v3"
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_upscale_resrgan(
|
||||
|
@ -28,14 +31,26 @@ def convert_upscale_resrgan(
|
|||
return
|
||||
|
||||
logger.info("loading and training model")
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
if TAG_X4_V3 in name:
|
||||
# the x4-v3 model needs a different network
|
||||
model = SRVGGNetCompact(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_conv=32,
|
||||
upscale=scale,
|
||||
act_type="prelu",
|
||||
)
|
||||
else:
|
||||
model = RRDBNet(
|
||||
num_in_ch=3,
|
||||
num_out_ch=3,
|
||||
num_feat=64,
|
||||
num_block=23,
|
||||
num_grow_ch=32,
|
||||
scale=scale,
|
||||
)
|
||||
|
||||
torch_model = torch.load(source, map_location=ctx.map_location)
|
||||
if "params_ema" in torch_model:
|
||||
|
|
Loading…
Reference in New Issue