1
0
Fork 0
onnx-web/api/onnx_web/convert/upscaling/swinir.py

116 lines
3.3 KiB
Python
Raw Permalink Normal View History

from logging import getLogger
from os import path
import torch
from torch.onnx import export
from ...models.swinir import SwinIR
from ..utils import ConversionContext, ModelDict
logger = getLogger(__name__)
@torch.no_grad()
def convert_upscaling_swinir(
conversion: ConversionContext,
model: ModelDict,
source: str,
):
name = model.get("name")
source = source or model.get("source")
scale = model.get("scale", 1)
dest = path.join(conversion.model_path, name + ".onnx")
logger.info("converting SwinIR model: %s -> %s", name, dest)
if path.isfile(dest):
2023-02-17 00:42:05 +00:00
logger.info("ONNX model already exists, skipping")
return
# values based on https://github.com/JingyunLiang/SwinIR/blob/main/main_test_swinir.py#L128
params = {
2023-04-11 02:04:48 +00:00
"depths": [6] * 6,
"embed_dim": 180,
"img_range": 1.0,
"img_size": (64, 64),
"in_chans": 3,
2023-04-11 02:04:48 +00:00
"num_heads": [6] * 6,
"resi_connection": "1conv",
"upsampler": "pixelshuffle",
"window_size": 8,
}
if "lightweight" in name:
logger.debug("using SwinIR lightweight params")
2023-04-11 02:04:48 +00:00
params["depths"] = [6] * 4
params["embed_dim"] = 60
2023-04-11 02:04:48 +00:00
params["num_heads"] = [6] * 4
params["upsampler"] = "pixelshuffledirect"
elif "real" in name:
if "large" in name:
logger.debug("using SwinIR real large params")
params["depths"] = [6] * 9
params["embed_dim"] = 240
params["num_heads"] = [8] * 9
params["upsampler"] = "nearest+conv"
params["resi_connection"] = "3conv"
else:
logger.debug("using SwinIR real params")
params["upsampler"] = "nearest+conv"
elif "gray_dn" in name:
params["img_size"] = (128, 128)
params["in_chans"] = 1
params["upsampler"] = ""
elif "color_dn" in name:
params["img_size"] = (128, 128)
params["upsampler"] = ""
elif "gray_jpeg" in name:
params["img_range"] = 255.0
params["img_size"] = (126, 126)
params["in_chans"] = 1
params["upsampler"] = ""
params["window_size"] = 7
elif "color_jpeg" in name:
params["img_range"] = 255.0
params["img_size"] = (126, 126)
params["upsampler"] = ""
params["window_size"] = 7
model = SwinIR(
mlp_ratio=2,
upscale=scale,
**params,
)
torch_model = torch.load(source, map_location=conversion.map_location)
if "params_ema" in torch_model:
model.load_state_dict(torch_model["params_ema"], strict=False)
2023-06-17 02:53:49 +00:00
elif "params" in torch_model:
model.load_state_dict(torch_model["params"], strict=False)
2023-06-17 02:53:49 +00:00
else:
model.load_state_dict(torch_model, strict=False)
model.to(conversion.training_device).train(False)
model.eval()
rng = torch.rand(1, 3, 64, 64, device=conversion.map_location)
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {
"input": {2: "h", 3: "w"},
"output": {2: "h", 3: "w"},
}
logger.info("exporting ONNX model to %s", dest)
export(
model,
rng,
dest,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=conversion.opset,
export_params=True,
)
logger.info("SwinIR exported to ONNX successfully")