2023-02-11 04:41:24 +00:00
|
|
|
from logging import getLogger
|
|
|
|
from os import path
|
2023-12-27 11:17:17 +00:00
|
|
|
from re import compile
|
2023-02-11 04:41:24 +00:00
|
|
|
|
|
|
|
import torch
|
2023-02-09 04:35:54 +00:00
|
|
|
from torch.onnx import export
|
2023-02-11 04:41:24 +00:00
|
|
|
|
2023-12-30 19:11:50 +00:00
|
|
|
from ...models.rrdb import RRDBNetFixed, RRDBNetRescale
|
2023-04-11 13:26:21 +00:00
|
|
|
from ...models.srvgg import SRVGGNetCompact
|
2023-04-10 22:49:56 +00:00
|
|
|
from ..utils import ConversionContext, ModelDict
|
2023-02-09 04:35:54 +00:00
|
|
|
|
|
|
|
logger = getLogger(__name__)

# substring of the model name that selects the compact SRVGG network
TAG_X4_V3 = "real-esrgan-x4-v3"

# Direct renames from the "model.N" checkpoint layout to the named-layer
# layout. Every layer contributes a bias and a weight entry, in that order.
SPECIAL_KEYS = {
    f"{old_prefix}.{param}": f"{new_prefix}.{param}"
    for old_prefix, new_prefix in (
        ("model.0", "conv_first"),
        ("model.1.sub.23", "conv_body"),
        ("model.3", "conv_up1"),
        ("model.6", "conv_up2"),
        # 1x model keys
        ("model.2", "conv_hr"),
        ("model.4", "conv_last"),
        # 2x and 4x model keys
        ("model.8", "conv_hr"),
        ("model.10", "conv_last"),
    )
    for param in ("bias", "weight")
}

# Matches the residual dense block parameters, capturing the sub-block index,
# the RDB index, the conv index, and the parameter kind (bias or weight).
SUB_NAME = compile(r"^model\.1\.sub\.(\d+)\.RDB(\d)\.conv(\d)\.0\.(bias|weight)$")
|
2023-12-27 11:08:15 +00:00
|
|
|
|
|
|
|
|
|
|
|
def fix_resrgan_keys(model):
    """
    Rename the keys of a state dict, in place, from the "model.N" layout to
    the named-layer layout.

    Keys found in SPECIAL_KEYS are mapped directly. Keys matching SUB_NAME
    (``model.1.sub.<n>.RDB<r>.conv<c>.0.<bias|weight>``) become
    ``body.<n>.rdb<r>.conv<c>.<bias|weight>``.

    Every key is validated before anything is renamed, so when an error is
    raised the mapping has not been modified.

    Returns the same mapping object that was passed in.

    Raises ValueError when a key is in neither format, or when a new name
    would collide with an existing or previously assigned name.
    """
    # names that are already in use: the original keys, plus every new name
    # handed out so far
    taken = set(model)
    renames = {}

    for key in list(model):
        new_key = SPECIAL_KEYS.get(key)
        if new_key is None:
            # convert RDBN keys
            matched = SUB_NAME.match(key)
            if matched is None:
                raise ValueError(f"unknown key format: {key!r}")

            sub_index, rdb_index, conv_index, node_type = matched.groups()
            new_key = f"body.{sub_index}.rdb{rdb_index}.conv{conv_index}.{node_type}"

        if new_key in taken:
            raise ValueError(f"key collision: {key!r} -> {new_key!r}")

        taken.add(new_key)
        renames[key] = new_key

    # all keys are valid, apply the renames in the original key order
    for key, new_key in renames.items():
        model[new_key] = model[key]
        del model[key]

    return model
|
|
|
|
|
2023-02-09 04:35:54 +00:00
|
|
|
|
|
|
|
@torch.no_grad()
def convert_upscale_resrgan(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
) -> None:
    """
    Convert a Real ESRGAN checkpoint into an ONNX model.

    The output is written to ``<conversion.model_path>/<name>.onnx``. If that
    file already exists, nothing is loaded or exported.

    conversion: supplies the output directory, the devices, and the ONNX opset.
    model: model description; the "name", "source", and "scale" entries are read.
    source: checkpoint path, falls back to the "source" entry of the model when empty.
    """
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale")

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting Real ESRGAN model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only convert checkpoints that come from a trusted source.
    checkpoint = torch.load(source, map_location=conversion.map_location)

    # the weights may be nested under one of these keys, prefer the EMA copy
    if "params_ema" in checkpoint:
        state_dict = checkpoint["params_ema"]
    elif "params" in checkpoint:
        state_dict = checkpoint["params"]
    else:
        state_dict = checkpoint

    if any("RDB" in key for key in state_dict):
        # keys need to be fixed up to match. capitalized RDB is the best indicator.
        state_dict = fix_resrgan_keys(state_dict)

    if TAG_X4_V3 in name:
        # the x4-v3 model needs a different network
        network = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=scale,
            act_type="prelu",
        )
    else:
        if "conv_up1.weight" in state_dict and "conv_up2.weight" in state_dict:
            # both variants are the same for scale=4
            rrdb_class = RRDBNetRescale
        else:
            rrdb_class = RRDBNetFixed

        network = rrdb_class(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=scale,
        )

    network.load_state_dict(state_dict, strict=True)
    network.to(conversion.training_device)
    network.eval()

    # the sample input has to be on the same device as the network, otherwise
    # tracing fails whenever map_location and training_device differ
    sample = torch.rand(1, 3, 64, 64, device=conversion.training_device)
    input_names = ["data"]
    output_names = ["output"]
    # NOTE(review): by the usual NCHW convention axis 2 is height and axis 3 is
    # width, so these labels look swapped. They are left unchanged here because
    # consumers of the exported model may refer to them by name - confirm.
    dynamic_axes = {
        "data": {2: "width", 3: "height"},
        "output": {2: "width", 3: "height"},
    }

    logger.info("exporting ONNX model to %s", dest)
    export(
        network,
        sample,
        dest,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=conversion.opset,
        export_params=True,
    )
    logger.info("real ESRGAN exported to ONNX successfully")
|