1
0
Fork 0
onnx-web/api/onnx_web/convert.py

91 lines
2.4 KiB
Python
Raw Normal View History

from argparse import ArgumentParser
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from os import path, environ
from sys import exit
from torch.onnx import export
import torch
from .upscale import (
gfpgan_url,
resrgan_url,
resrgan_name,
)
# Directory where checkpoints are downloaded and ONNX models are written.
# Override with the ONNX_WEB_MODEL_PATH environment variable; the fallback
# '../models' is resolved relative to the current working directory.
model_path = environ.get(
    'ONNX_WEB_MODEL_PATH',
    path.join('..', 'models'),
)
@torch.no_grad()
def convert_real_esrgan():
    """Download (if needed) the Real ESRGAN checkpoint and export it to ONNX.

    The PyTorch checkpoint is expected at ``<model_path>/<resrgan_name>.pth``.
    When that file is missing, it is downloaded from ``resrgan_url`` to exactly
    that location. The exported model is written to
    ``<model_path>/<resrgan_name>.onnx`` with dynamic spatial dimensions.
    """
    dest_path = path.join(model_path, resrgan_name + '.pth')
    print('converting Real ESRGAN into %s' % dest_path)

    if not path.isfile(dest_path):
        print('existing model not found, downloading...')
        for url in resrgan_url:
            # Save directly as <model_path>/<resrgan_name>.pth so the isfile()
            # check above finds it on the next run. Passing
            # model_dir=path.join(dest_path, ...) instead created a directory
            # named "<name>.pth" and hid the weights inside it.
            dest_path = load_file_from_url(
                url=url,
                model_dir=model_path,
                progress=True,
                file_name=resrgan_name + '.pth')

    # Hyper-parameters must match the checkpoint: 3-channel in/out, 4x upscale.
    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
                    num_block=23, num_grow_ch=32, scale=4)
    print('loading Real ESRGAN model')
    # map_location='cpu' lets a checkpoint saved with GPU tensors load on a
    # CPU-only host. The weights are read from the 'params_ema' entry.
    model.load_state_dict(
        torch.load(dest_path, map_location='cpu')['params_ema'])
    # eval() is equivalent to train(False); switch to inference behaviour.
    model.eval()

    # Example NCHW input used to trace the graph. The spatial size is
    # arbitrary because height and width are declared dynamic below.
    rng = torch.rand(1, 3, 64, 64)
    input_names = ['data']
    output_names = ['output']
    # For NCHW tensors, axis 2 is height and axis 3 is width.
    dynamic_axes = {'data': {2: 'height', 3: 'width'},
                    'output': {2: 'height', 3: 'width'}}

    dest_onnx = path.join(model_path, resrgan_name + '.onnx')
    print('exporting Real ESRGAN model to %s' % dest_onnx)
    # Autograd is already disabled by the @torch.no_grad() decorator.
    export(
        model,
        rng,
        dest_onnx,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=11,
        export_params=True,
    )
    print('Real ESRGAN exported to ONNX')
def convert_gfpgan():
    """Placeholder for GFPGAN model conversion.

    Not implemented yet: performs no work and returns ``None``.
    """
    return None
def convert_diffuser():
    """Placeholder for diffusers pipeline model conversion.

    Not implemented yet: performs no work and returns ``None``.
    """
    return None
def main(argv=None) -> int:
    """Parse command-line arguments and run the requested conversions.

    Args:
        argv: argument list to parse, without the program name. When ``None``
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        Process exit status; 0 once parsing succeeds and conversions finish.
    """
    parser = ArgumentParser(
        prog='onnx-web model converter',
        description='convert checkpoint models to ONNX')
    # default=[] keeps the loop below working when --diffusers is omitted.
    # Without it, argparse stores None and iterating it raises TypeError.
    parser.add_argument('--diffusers', type=str, nargs='+', default=[],
                        help='models using the diffusers pipeline')
    parser.add_argument('--gfpgan', action='store_true',
                        help='convert the GFPGAN model')
    parser.add_argument('--resrgan', action='store_true',
                        help='convert the Real ESRGAN model')
    args = parser.parse_args(argv)
    print(args)

    for model in args.diffusers:
        print('convert ' + model)

    # Dispatch every parsed flag; --gfpgan was previously accepted but ignored.
    if args.gfpgan:
        convert_gfpgan()

    if args.resrgan:
        convert_real_esrgan()

    return 0
# Script entry point: run the converter and hand its status to the process.
if __name__ == '__main__':
    exit_code = main()
    exit(exit_code)