# onnx-web/api/onnx_web/convert/correction/gfpgan.py (Python, 70 lines, 1.8 KiB)

from logging import getLogger
from os import path
import torch
from torch.onnx import export
from ...models.rrdb import RRDBNet
from ..utils import ConversionContext, ModelDict
# Module-level logger, named after this module via __name__.
logger = getLogger(__name__)
@torch.no_grad()
def convert_correction_gfpgan(
    conversion: ConversionContext,
    model: ModelDict,
    source: str,
):
    """Export a GFPGAN correction checkpoint to an ONNX file.

    Builds an ``RRDBNet``, loads the checkpoint weights into it, and traces
    it with ``torch.onnx.export``. The result is written to
    ``<conversion.model_path>/<name>.onnx``. If that file already exists the
    conversion is skipped.

    Args:
        conversion: Conversion settings. This function reads ``model_path``,
            ``map_location``, ``training_device`` and ``opset`` from it.
        model: Model description. ``"name"`` is used for the output file
            name, ``"source"`` is the fallback checkpoint location, and
            ``"scale"`` (default ``1``) is forwarded to ``RRDBNet``.
        source: Checkpoint location handed to ``torch.load``. When falsy,
            ``model["source"]`` is used instead.

    Returns:
        None. The exported model is written to disk as a side effect.
    """
    name = model.get("name")
    source = source or model.get("source")
    scale = model.get("scale", 1)

    dest = path.join(conversion.model_path, name + ".onnx")
    logger.info("converting GFPGAN model: %s -> %s", name, dest)

    if path.isfile(dest):
        logger.info("ONNX model already exists, skipping")
        return

    logger.info("loading and training model")
    # Use a separate local for the network so the `model` config dict
    # passed by the caller is not shadowed.
    network = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    # NOTE(review): torch.load unpickles the checkpoint; only use it on
    # trusted files.
    checkpoint = torch.load(source, map_location=conversion.map_location)

    # Prefer the "params_ema" weights when the checkpoint provides them,
    # otherwise fall back to "params".
    # TODO: make sure strict=False is safe here
    weights_key = "params_ema" if "params_ema" in checkpoint else "params"
    network.load_state_dict(checkpoint[weights_key], strict=False)

    # Move the network to the target device and switch it to inference mode.
    device = conversion.training_device
    network.to(device).eval()

    # The dummy input used for tracing must live on the same device as the
    # network, so create it on the device the network was just moved to
    # rather than on the checkpoint's map_location.
    rng = torch.rand(1, 3, 64, 64, device=device)
    input_names = ["data"]
    output_names = ["output"]
    # Axes 2 and 3 (the spatial axes) are exported as dynamic.
    # NOTE(review): PyTorch image tensors are conventionally (N, C, H, W),
    # so these labels look swapped, and the output shares the input's labels
    # even when scale != 1. The names are kept as-is because they are baked
    # into the exported model - confirm before renaming.
    dynamic_axes = {
        "data": {2: "width", 3: "height"},
        "output": {2: "width", 3: "height"},
    }

    logger.info("exporting ONNX model to %s", dest)
    export(
        network,
        rng,
        dest,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=conversion.opset,
        export_params=True,
    )
    logger.info("GFPGAN exported to ONNX successfully")