1
0
Fork 0
onnx-web/api/onnx_web/convert/upscaling/resrgan.py

126 lines
3.5 KiB
Python

from logging import getLogger
from os import path
from re import compile
import torch
from torch.onnx import export
from ...models.srvgg import SRVGGNetCompact
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)

TAG_X4_V3 = "real-esrgan-x4-v3"

# Exact renames from the legacy ESRGAN checkpoint layout (``model.N.*``) to the
# top-level layer names of basicsr's ``RRDBNet``.
SPECIAL_KEYS = {
    "model.0.bias": "conv_first.bias",
    "model.0.weight": "conv_first.weight",
    "model.1.sub.23.bias": "conv_body.bias",
    "model.1.sub.23.weight": "conv_body.weight",
    "model.3.bias": "conv_up1.bias",
    "model.3.weight": "conv_up1.weight",
    "model.6.bias": "conv_up2.bias",
    "model.6.weight": "conv_up2.weight",
    "model.8.bias": "conv_hr.bias",
    "model.8.weight": "conv_hr.weight",
    "model.10.bias": "conv_last.bias",
    "model.10.weight": "conv_last.weight",
}

# Legacy residual-dense-block keys:
#   model.1.sub.<block>.RDB<rdb>.conv<conv>.0.<bias|weight>
# Each index is captured with ``(\d+)`` rather than ``(\d)+``; the latter only
# keeps the LAST digit, so block 12 would collide with block 2.
SUB_NAME = compile(r"^model\.1\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(bias|weight)$")


def _fix_resrgan_key(key):
    """Return the basicsr ``RRDBNet`` name for a single legacy ESRGAN key.

    Raises:
        ValueError: if ``key`` matches neither ``SPECIAL_KEYS`` nor ``SUB_NAME``.
    """
    if key in SPECIAL_KEYS:
        return SPECIAL_KEYS[key]

    matched = SUB_NAME.match(key)
    if matched is None:
        raise ValueError(f"unknown key format: {key}")

    sub_index, rdb_index, conv_index, node_type = matched.groups()
    # RRDBNet keeps its RRDB blocks in ``body`` (body.<block>.rdb<n>.conv<n>.<param>);
    # the legacy ``.0`` index of the conv+activation Sequential is dropped.
    return f"body.{sub_index}.rdb{rdb_index}.conv{conv_index}.{node_type}"


def fix_resrgan_keys(model):
    """Rename legacy ESRGAN state-dict keys in place to match basicsr's RRDBNet.

    Args:
        model: mutable mapping of legacy checkpoint key -> value (the values
            are carried over untouched).

    Returns:
        The same mapping object, with every key renamed.

    Raises:
        ValueError: if any key has an unrecognised format. Every key is
            translated before the mapping is modified, so it is left
            unchanged in that case.
    """
    # translate everything first so a bad key cannot leave a half-renamed dict
    renamed = {_fix_resrgan_key(key): value for key, value in model.items()}
    model.clear()
    model.update(renamed)
    return model
@torch.no_grad()
def convert_upscale_resrgan(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    """Convert a Real-ESRGAN PyTorch checkpoint into an ONNX model.

    The ONNX file is written to ``<conversion.model_path>/<name>.onnx``. If that
    file already exists, nothing is done.

    Args:
        conversion: conversion settings; reads ``model_path``, ``map_location``,
            ``training_device`` and ``opset``.
        model: model entry; reads ``name`` (must be a string), ``source`` (used
            when the ``source`` argument is falsy) and ``scale``.
        source: path of the PyTorch checkpoint to convert.
    """
    # imported lazily, presumably so basicsr is only required when converting
    from basicsr.archs.rrdbnet_arch import RRDBNet

    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale")

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting Real ESRGAN model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    if TAG_X4_V3 in name:
        # the x4-v3 model needs a different network
        net = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=scale,
            act_type="prelu",
        )
    else:
        net = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=scale,
        )

    # NOTE(review): torch.load unpickles the checkpoint and can execute
    # arbitrary code; only convert checkpoints from trusted sources.
    torch_model = torch.load(source, map_location=conversion.map_location)
    if "params_ema" in torch_model:
        load_result = net.load_state_dict(torch_model["params_ema"])
    elif "params" in torch_model:
        load_result = net.load_state_dict(torch_model["params"], strict=False)
    else:
        # keys need fixed up to match
        load_result = net.load_state_dict(fix_resrgan_keys(torch_model), strict=False)

    # non-strict loading silently skips mismatched keys, so surface them
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.warning(
            "checkpoint does not fully match network, missing keys: %s, unexpected keys: %s",
            load_result.missing_keys,
            load_result.unexpected_keys,
        )

    # eval() is the same as train(False); Module.to moves the weights in place
    net.to(conversion.training_device).eval()

    # random sample used only for tracing; it must live on the same device as
    # the network weights
    sample_input = torch.rand(1, 3, 64, 64, device=conversion.training_device)

    input_names = ["data"]
    output_names = ["output"]
    # NCHW tensors: dim 2 is height, dim 3 is width. The output is upscaled, so
    # its spatial dims get their own symbolic names instead of claiming to be
    # equal to the input's.
    dynamic_axes = {
        "data": {2: "height", 3: "width"},
        "output": {2: "output_height", 3: "output_width"},
    }

    logger.info("exporting ONNX model to %s", dest)
    export(
        net,
        sample_input,
        dest,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=conversion.opset,
        export_params=True,
    )
    logger.info("real ESRGAN exported to ONNX successfully")