1
0
Fork 0
onnx-web/api/onnx_web/convert/upscaling/resrgan.py

83 lines
2.1 KiB
Python
Raw Normal View History

from logging import getLogger
from os import path
import torch
from torch.onnx import export
from ...models.srvgg import SRVGGNetCompact
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
TAG_X4_V3 = "real-esrgan-x4-v3"
@torch.no_grad()
def convert_upscale_resrgan(
conversion: ConversionContext,
model: ModelDict,
source: str,
):
from basicsr.archs.rrdbnet_arch import RRDBNet
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale")
dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting Real ESRGAN model: %s -> %s", name, dest)
if path.isfile(dest):
2023-02-17 00:42:05 +00:00
logger.info("ONNX model already exists, skipping")
return
if TAG_X4_V3 in name:
# the x4-v3 model needs a different network
model = SRVGGNetCompact(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_conv=32,
upscale=scale,
act_type="prelu",
)
else:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=scale,
)
torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"])
else:
model.load_state_dict(torch_model["params"], strict=False)
model.to(conversion.training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
input_names = ["data"]
output_names = ["output"]
dynamic_axes = {
"data": {2: "width", 3: "height"},
"output": {2: "width", 3: "height"},
}
logger.info("exporting ONNX model to %s", dest)
export(
model,
rng,
dest,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=conversion.opset,
export_params=True,
)
2023-02-17 00:42:05 +00:00
logger.info("real ESRGAN exported to ONNX successfully")