1
0
Fork 0

fix(api): copy checkpoints into correct location, handle more models

This commit is contained in:
Sean Sube 2023-01-16 20:53:12 -06:00
parent 1db5ebec84
commit 353a65513f
1 changed files with 26 additions and 13 deletions

View File

@ -5,7 +5,7 @@ from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffu
from onnx import load, save_model
from os import mkdir, path, environ
from pathlib import Path
from shutil import rmtree
from shutil import copyfile, rmtree
from sys import exit
from torch.onnx import export
from typing import Dict, List, Tuple
@ -40,13 +40,13 @@ sources: Dict[str, List[Tuple[str, str]]] = {
model_path = environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models'))
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
map_location = None if torch.cuda.is_available() else torch.device('cpu')
@torch.no_grad()
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
dest_path = path.join(model_path, name)
dest_path = path.join(model_path, name + '.pth')
dest_onnx = path.join(model_path, name + '.onnx')
print('converting Real ESRGAN model: %s -> %s' % (name, dest_path))
print('converting Real ESRGAN model: %s -> %s' % (name, dest_onnx))
if path.isfile(dest_onnx):
print('ONNX model already exists, skipping.')
@ -54,13 +54,20 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
if not path.isfile(dest_path):
print('PTH model not found, downloading...')
dest_path = load_file_from_url(
url=url, model_dir=path.join(dest_path, name), progress=True, file_name=None)
download_path = load_file_from_url(
url=url, model_dir=dest_path + '.cache', progress=True, file_name=None)
copyfile(download_path, dest_path)
print('loading and training model')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=scale)
model.load_state_dict(torch.load(dest_path)['params_ema'])
torch_model = torch.load(dest_path, map_location=map_location)
if 'params_ema' in torch_model:
model.load_state_dict(torch_model['params_ema'])
else:
model.load_state_dict(torch_model['params'], strict=False)
model.to(training_device).train(False)
model.eval()
@ -86,9 +93,9 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
@torch.no_grad()
def convert_gfpgan(name: str, url: str, opset: int):
dest_path = path.join(model_path, name)
dest_path = path.join(model_path, name + '.pth')
dest_onnx = path.join(model_path, name + '.onnx')
print('converting GFPGAN model: %s -> %s' % (name, dest_path))
print('converting GFPGAN model: %s -> %s' % (name, dest_onnx))
if path.isfile(dest_onnx):
print('ONNX model already exists, skipping.')
@ -96,15 +103,21 @@ def convert_gfpgan(name: str, url: str, opset: int):
if not path.isfile(dest_path):
print('PTH model not found, downloading...')
dest_path = load_file_from_url(
url=url, model_dir=path.join(dest_path, name), progress=True, file_name=None)
download_path = load_file_from_url(
url=url, model_dir=dest_path + '.cache', progress=True, file_name=None)
copyfile(download_path, dest_path)
print('loading and training model')
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
num_block=23, num_grow_ch=32, scale=4)
torch_model = torch.load(dest_path, map_location=map_location)
# TODO: make sure strict=False is safe here
model.load_state_dict(torch.load(dest_path)['params_ema'], strict=False)
if 'params_ema' in torch_model:
model.load_state_dict(torch_model['params_ema'], strict=False)
else:
model.load_state_dict(torch_model['params'], strict=False)
model.to(training_device).train(False)
model.eval()
@ -358,7 +371,7 @@ def main() -> int:
description='convert checkpoint models to ONNX')
# model groups
parser.add_argument('--diffusers', action='store_true', default=True)
parser.add_argument('--diffusers', action='store_true', default=False)
parser.add_argument('--gfpgan', action='store_true', default=False)
parser.add_argument('--resrgan', action='store_true', default=False)
parser.add_argument('--skip', nargs='*', type=str, default=[])