fix(api): copy checkpoints into correct location, handle more models
This commit is contained in:
parent
1db5ebec84
commit
353a65513f
|
@ -5,7 +5,7 @@ from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffu
|
|||
from onnx import load, save_model
|
||||
from os import mkdir, path, environ
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from shutil import copyfile, rmtree
|
||||
from sys import exit
|
||||
from torch.onnx import export
|
||||
from typing import Dict, List, Tuple
|
||||
|
@ -40,13 +40,13 @@ sources: Dict[str, List[Tuple[str, str]]] = {
|
|||
model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
||||
path.join('..', 'models'))
|
||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
map_location = None if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
||||
dest_path = path.join(model_path, name)
|
||||
dest_path = path.join(model_path, name + '.pth')
|
||||
dest_onnx = path.join(model_path, name + '.onnx')
|
||||
print('converting Real ESRGAN model: %s -> %s' % (name, dest_path))
|
||||
print('converting Real ESRGAN model: %s -> %s' % (name, dest_onnx))
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
print('ONNX model already exists, skipping.')
|
||||
|
@ -54,13 +54,20 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
|||
|
||||
if not path.isfile(dest_path):
|
||||
print('PTH model not found, downloading...')
|
||||
dest_path = load_file_from_url(
|
||||
url=url, model_dir=path.join(dest_path, name), progress=True, file_name=None)
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + '.cache', progress=True, file_name=None)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
print('loading and training model')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=scale)
|
||||
model.load_state_dict(torch.load(dest_path)['params_ema'])
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=map_location)
|
||||
if 'params_ema' in torch_model:
|
||||
model.load_state_dict(torch_model['params_ema'])
|
||||
else:
|
||||
model.load_state_dict(torch_model['params'], strict=False)
|
||||
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
|
@ -86,9 +93,9 @@ def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
|
|||
|
||||
@torch.no_grad()
|
||||
def convert_gfpgan(name: str, url: str, opset: int):
|
||||
dest_path = path.join(model_path, name)
|
||||
dest_path = path.join(model_path, name + '.pth')
|
||||
dest_onnx = path.join(model_path, name + '.onnx')
|
||||
print('converting GFPGAN model: %s -> %s' % (name, dest_path))
|
||||
print('converting GFPGAN model: %s -> %s' % (name, dest_onnx))
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
print('ONNX model already exists, skipping.')
|
||||
|
@ -96,15 +103,21 @@ def convert_gfpgan(name: str, url: str, opset: int):
|
|||
|
||||
if not path.isfile(dest_path):
|
||||
print('PTH model not found, downloading...')
|
||||
dest_path = load_file_from_url(
|
||||
url=url, model_dir=path.join(dest_path, name), progress=True, file_name=None)
|
||||
download_path = load_file_from_url(
|
||||
url=url, model_dir=dest_path + '.cache', progress=True, file_name=None)
|
||||
copyfile(download_path, dest_path)
|
||||
|
||||
print('loading and training model')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=4)
|
||||
|
||||
torch_model = torch.load(dest_path, map_location=map_location)
|
||||
# TODO: make sure strict=False is safe here
|
||||
model.load_state_dict(torch.load(dest_path)['params_ema'], strict=False)
|
||||
if 'params_ema' in torch_model:
|
||||
model.load_state_dict(torch_model['params_ema'], strict=False)
|
||||
else:
|
||||
model.load_state_dict(torch_model['params'], strict=False)
|
||||
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
|
@ -358,7 +371,7 @@ def main() -> int:
|
|||
description='convert checkpoint models to ONNX')
|
||||
|
||||
# model groups
|
||||
parser.add_argument('--diffusers', action='store_true', default=True)
|
||||
parser.add_argument('--diffusers', action='store_true', default=False)
|
||||
parser.add_argument('--gfpgan', action='store_true', default=False)
|
||||
parser.add_argument('--resrgan', action='store_true', default=False)
|
||||
parser.add_argument('--skip', nargs='*', type=str, default=[])
|
||||
|
|
Loading…
Reference in New Issue