from logging import getLogger
from os import path

import torch
from torch.onnx import export

from ...models.swinir import SwinIR
from ..utils import ConversionContext, ModelDict

logger = getLogger(__name__)


@torch.no_grad()
def convert_upscaling_swinir(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    """Load SwinIR weights from a Torch checkpoint and export them to ONNX."""
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale", 1)

    dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting SwinIR model: %s -> %s", name, dest)
|
|
|
|
if path.isfile(dest):
|
|
logger.info("ONNX model already exists, skipping")
|
|
return
|
|
|
|
logger.info("loading and training model")
|
|
    # values based on https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py#L128
    params = {
        "depths": [6] * 6,
        "embed_dim": 180,
        "img_range": 1.0,
        "img_size": (64, 64),
        "in_chans": 3,
        "num_heads": [6] * 6,
        "resi_connection": "1conv",
        "upsampler": "pixelshuffle",
        "window_size": 8,
    }

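    # the defaults above match the classical SR checkpoints; the branches below
    # override them per variant and must agree with the checkpoint being loaded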
if "lightweight" in name:
|
|
logger.debug("using SwinIR lightweight params")
|
|
params["depths"] = [6] * 4
|
|
params["embed_dim"] = 60
|
|
params["num_heads"] = [6] * 4
|
|
params["upsampler"] = "pixelshuffledirect"
|
|
elif "real" in name:
|
|
if "large" in name:
|
|
logger.debug("using SwinIR real large params")
|
|
params["depths"] = [6] * 9
|
|
params["embed_dim"] = 240
|
|
params["num_heads"] = [8] * 9
|
|
params["upsampler"] = "nearest+conv"
|
|
params["resi_connection"] = "3conv"
|
|
else:
|
|
logger.debug("using SwinIR real params")
|
|
params["upsampler"] = "nearest+conv"
|
|
elif "gray_dn" in name:
|
|
params["img_size"] = (128, 128)
|
|
params["in_chans"] = 1
|
|
params["upsampler"] = ""
|
|
elif "color_dn" in name:
|
|
params["img_size"] = (128, 128)
|
|
params["upsampler"] = ""
|
|
elif "gray_jpeg" in name:
|
|
params["img_range"] = 255.0
|
|
params["img_size"] = (126, 126)
|
|
params["in_chans"] = 1
|
|
params["upsampler"] = ""
|
|
params["window_size"] = 7
|
|
elif "color_jpeg" in name:
|
|
params["img_range"] = 255.0
|
|
params["img_size"] = (126, 126)
|
|
params["upsampler"] = ""
|
|
params["window_size"] = 7
|
|
|
|
    model = SwinIR(
        mlp_ratio=2,
        upscale=scale,
        **params,
    )

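    # checkpoints may store weights under "params_ema" (exponential moving
    # average) or "params"; strict=False tolerates minor key differences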
    torch_model = torch.load(source, map_location=conversion.map_location)
    if "params_ema" in torch_model:
        model.load_state_dict(torch_model["params_ema"], strict=False)
    else:
        model.load_state_dict(torch_model["params"], strict=False)

    model.to(conversion.training_device).eval()

    # the dummy input must match the model: use the configured channel count
    # and image size, on the same device the model was moved to
    rng = torch.rand(
        1,
        params["in_chans"],
        *params["img_size"],
        device=conversion.training_device,
    )
    input_names = ["input"]
    output_names = ["output"]
    dynamic_axes = {
        # mark H and W as symbolic so the exported model accepts any resolution
        "input": {2: "h", 3: "w"},
        "output": {2: "h", 3: "w"},
    }

logger.info("exporting ONNX model to %s", dest)
|
|
export(
|
|
model,
|
|
rng,
|
|
dest,
|
|
input_names=input_names,
|
|
output_names=output_names,
|
|
dynamic_axes=dynamic_axes,
|
|
opset_version=conversion.opset,
|
|
export_params=True,
|
|
)
|
|
logger.info("SwinIR exported to ONNX successfully")
|
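

# A minimal verification sketch, not part of the original module: spot-check an
# exported model with onnxruntime. The helper name and the onnxruntime
# dependency are assumptions; the input/output names match the export above.
def check_onnx_output(dest: str, scale: int = 1, channels: int = 3):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(dest)
    sample = np.random.rand(1, channels, 64, 64).astype(np.float32)
    (output,) = session.run(["output"], {"input": sample})

    # an upscaling model should multiply H and W by its scale factor
    assert output.shape == (1, channels, 64 * scale, 64 * scale)
    return output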