2023-01-28 23:09:19 +00:00
|
|
|
from . import logging
|
2023-01-16 15:57:59 +00:00
|
|
|
from argparse import ArgumentParser
|
|
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
|
|
from basicsr.utils.download_util import load_file_from_url
|
2023-01-28 23:09:19 +00:00
|
|
|
from diffusers import (
|
|
|
|
OnnxRuntimeModel,
|
|
|
|
OnnxStableDiffusionPipeline,
|
|
|
|
StableDiffusionPipeline,
|
2023-01-29 21:23:01 +00:00
|
|
|
StableDiffusionUpscalePipeline,
|
2023-01-28 23:09:19 +00:00
|
|
|
)
|
2023-02-02 14:59:47 +00:00
|
|
|
from json import loads
|
2023-01-28 23:09:19 +00:00
|
|
|
from logging import getLogger
|
2023-01-16 23:48:50 +00:00
|
|
|
from onnx import load, save_model
|
2023-01-31 23:09:23 +00:00
|
|
|
from os import environ, makedirs, mkdir, path
|
2023-01-16 23:48:50 +00:00
|
|
|
from pathlib import Path
|
2023-01-17 02:53:12 +00:00
|
|
|
from shutil import copyfile, rmtree
|
2023-01-16 15:57:59 +00:00
|
|
|
from sys import exit
|
2023-01-16 20:17:50 +00:00
|
|
|
from torch.onnx import export
|
2023-01-27 23:08:36 +00:00
|
|
|
from typing import Dict, List, Optional, Tuple
|
2023-01-16 15:57:59 +00:00
|
|
|
|
|
|
|
import torch
|
2023-01-22 21:45:09 +00:00
|
|
|
import warnings
|
|
|
|
|
|
|
|
# suppress common but harmless warnings, https://github.com/ssube/onnx-web/issues/75
warnings.filterwarnings(
    'ignore', '.*The shape inference of prim::Constant type is missing.*')
warnings.filterwarnings('ignore', '.*Only steps=1 can be constant folded.*')
warnings.filterwarnings(
    'ignore', '.*Converting a tensor to a Python boolean might cause the trace to be incorrect.*')

# Catalog type: model-group name ('diffusion', 'upscaling', 'correction') mapped
# to a list of (destination name, source URL or HF hub id, scale) tuples.
# NOTE(review): the 'diffusion' entries below are 2-tuples while this alias
# declares 3-tuples; the trailing int only appears for upscaling/correction
# models — confirm whether the alias should use variadic tuples.
Models = Dict[str, List[Tuple[str, str, Optional[int]]]]

logger = getLogger(__name__)


# recommended models
base_models: Models = {
    'diffusion': [
        # v1.x
        ('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
        ('stable-diffusion-onnx-v1-inpainting',
         'runwayml/stable-diffusion-inpainting'),
        # v2.x
        ('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
        ('stable-diffusion-onnx-v2-inpainting',
         'stabilityai/stable-diffusion-2-inpainting'),
        # TODO: should have its own converter
        ('upscaling-stable-diffusion-x4', 'stabilityai/stable-diffusion-x4-upscaler'),
    ],
    'correction': [
        ('correction-gfpgan-v1-3',
         'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', 4),
    ],
    'upscaling': [
        ('upscaling-real-esrgan-x2-plus',
         'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', 2),
        ('upscaling-real-esrgan-x4-plus',
         'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', 4),
        ('upscaling-real-esrgan-x4-v3',
         'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', 4),
    ],
}

# destination directory for converted models; overridable via environment
model_path = environ.get('ONNX_WEB_MODEL_PATH',
                         path.join('..', 'models'))
# convert on GPU when CUDA is available, otherwise fall back to CPU
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
map_location = torch.device(training_device)
|
2023-01-16 23:48:50 +00:00
|
|
|
|
2023-01-28 23:09:19 +00:00
|
|
|
|
2023-01-16 20:17:50 +00:00
|
|
|
@torch.no_grad()
def convert_real_esrgan(name: str, url: str, scale: int, opset: int):
    '''
    Fetch a Real ESRGAN checkpoint and export it as an ONNX model.

    Downloads the .pth checkpoint when it is not already cached under
    model_path, loads it into an RRDBNet, and exports the network to
    {model_path}/{name}.onnx. Skips the export when the ONNX file exists.
    '''
    pth_file = path.join(model_path, name + '.pth')
    onnx_file = path.join(model_path, name + '.onnx')
    logger.info('converting Real ESRGAN model: %s -> %s', name, onnx_file)

    # an existing ONNX file means a previous conversion succeeded
    if path.isfile(onnx_file):
        logger.info('ONNX model already exists, skipping.')
        return

    if not path.isfile(pth_file):
        logger.info('PTH model not found, downloading...')
        cached = load_file_from_url(
            url=url, model_dir=pth_file + '-cache', progress=True, file_name=None)
        copyfile(cached, pth_file)

    logger.info('loading and training model')
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    checkpoint = torch.load(pth_file, map_location=map_location)
    # prefer the EMA weights when the checkpoint provides them
    if 'params_ema' in checkpoint:
        model.load_state_dict(checkpoint['params_ema'])
    else:
        model.load_state_dict(checkpoint['params'], strict=False)

    model.to(training_device).train(False)
    model.eval()

    logger.info('exporting ONNX model to %s', onnx_file)
    export(
        model,
        # dummy input used only to trace the graph; spatial dims are dynamic
        torch.rand(1, 3, 64, 64, device=map_location),
        onnx_file,
        input_names=['data'],
        output_names=['output'],
        dynamic_axes={
            'data': {2: 'width', 3: 'height'},
            'output': {2: 'width', 3: 'height'},
        },
        opset_version=opset,
        export_params=True,
    )
    logger.info('Real ESRGAN exported to ONNX successfully.')
|
2023-01-16 23:08:59 +00:00
|
|
|
|
|
|
|
|
2023-01-16 23:48:50 +00:00
|
|
|
@torch.no_grad()
def convert_gfpgan(name: str, url: str, scale: int, opset: int):
    '''
    Fetch a GFPGAN checkpoint and export it as an ONNX model.

    Downloads the .pth checkpoint when it is not already cached under
    model_path, loads it into an RRDBNet, and exports the network to
    {model_path}/{name}.onnx. Skips the export when the ONNX file exists.
    '''
    pth_file = path.join(model_path, name + '.pth')
    onnx_file = path.join(model_path, name + '.onnx')
    logger.info('converting GFPGAN model: %s -> %s', name, onnx_file)

    # an existing ONNX file means a previous conversion succeeded
    if path.isfile(onnx_file):
        logger.info('ONNX model already exists, skipping.')
        return

    if not path.isfile(pth_file):
        logger.info('PTH model not found, downloading...')
        cached = load_file_from_url(
            url=url, model_dir=pth_file + '-cache', progress=True, file_name=None)
        copyfile(cached, pth_file)

    logger.info('loading and training model')
    model = RRDBNet(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_block=23,
        num_grow_ch=32,
        scale=scale,
    )

    checkpoint = torch.load(pth_file, map_location=map_location)
    # TODO: make sure strict=False is safe here
    # both branches of the original used strict=False, so the key choice
    # is the only difference: prefer EMA weights when available
    weights_key = 'params_ema' if 'params_ema' in checkpoint else 'params'
    model.load_state_dict(checkpoint[weights_key], strict=False)

    model.to(training_device).train(False)
    model.eval()

    logger.info('exporting ONNX model to %s', onnx_file)
    export(
        model,
        # dummy input used only to trace the graph; spatial dims are dynamic
        torch.rand(1, 3, 64, 64, device=map_location),
        onnx_file,
        input_names=['data'],
        output_names=['output'],
        dynamic_axes={
            'data': {2: 'width', 3: 'height'},
            'output': {2: 'width', 3: 'height'},
        },
        opset_version=opset,
        export_params=True,
    )
    logger.info('GFPGAN exported to ONNX successfully.')
|
2023-01-16 23:48:50 +00:00
|
|
|
|
|
|
|
|
|
|
|
def onnx_export(
|
|
|
|
model,
|
|
|
|
model_args: tuple,
|
|
|
|
output_path: Path,
|
|
|
|
ordered_input_names,
|
|
|
|
output_names,
|
|
|
|
dynamic_axes,
|
|
|
|
opset,
|
|
|
|
use_external_data_format=False,
|
|
|
|
):
|
|
|
|
'''
|
|
|
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
|
|
|
'''
|
|
|
|
|
|
|
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
export(
|
|
|
|
model,
|
|
|
|
model_args,
|
|
|
|
f=output_path.as_posix(),
|
|
|
|
input_names=ordered_input_names,
|
|
|
|
output_names=output_names,
|
|
|
|
dynamic_axes=dynamic_axes,
|
|
|
|
do_constant_folding=True,
|
|
|
|
opset_version=opset,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def convert_diffuser(name: str, url: str, opset: int, half: bool, token: str, single_vae: bool = False):
    '''
    Download a diffusers pipeline and export each component to ONNX.

    From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py

    Exports the text encoder, UNet, VAE (either a single decode-only VAE for
    upscaling pipelines, or a separate encoder/decoder pair), and the optional
    safety checker into {model_path}/{name}/, then reassembles and saves an
    ONNX pipeline and verifies it can be loaded again.
    '''
    dtype = torch.float16 if half else torch.float32
    dest_path = path.join(model_path, name)

    # diffusers go into a directory rather than .onnx file
    logger.info('converting Diffusers model: %s -> %s/', name, dest_path)

    if single_vae:
        logger.info('converting model with single VAE')

    # an existing directory means a previous conversion succeeded
    if path.isdir(dest_path):
        logger.info('ONNX model already exists, skipping.')
        return

    if half and training_device != 'cuda':
        raise ValueError(
            'Half precision model export is only supported on GPUs with CUDA')

    pipeline = StableDiffusionPipeline.from_pretrained(
        url, torch_dtype=dtype, use_auth_token=token).to(training_device)
    output_path = Path(dest_path)

    # TEXT ENCODER
    num_tokens = pipeline.text_encoder.config.max_position_embeddings
    text_hidden_size = pipeline.text_encoder.config.hidden_size
    text_input = pipeline.tokenizer(
        "A sample prompt",
        padding="max_length",
        max_length=pipeline.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    onnx_export(
        pipeline.text_encoder,
        # casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
        model_args=(text_input.input_ids.to(
            device=training_device, dtype=torch.int32)),
        output_path=output_path / "text_encoder" / "model.onnx",
        ordered_input_names=["input_ids"],
        output_names=["last_hidden_state", "pooler_output"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "sequence"},
        },
        opset=opset,
    )
    # free the component as soon as it is exported to limit peak memory
    del pipeline.text_encoder

    logger.debug('UNET config: %s', pipeline.unet.config)

    # UNET
    if single_vae:
        # the upscaling UNet takes class_labels instead of return_dict
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "class_labels"]
        # unet_inputs = ["latent_model_input", "timestep", "encoder_hidden_states", "class_labels"]
        unet_scale = 4
    else:
        unet_inputs = ["sample", "timestep", "encoder_hidden_states", "return_dict"]
        unet_scale = False

    unet_in_channels = pipeline.unet.config.in_channels
    unet_sample_size = pipeline.unet.config.sample_size
    unet_path = output_path / "unet" / "model.onnx"
    onnx_export(
        pipeline.unet,
        model_args=(
            torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
                device=training_device, dtype=dtype),
            torch.randn(2).to(device=training_device, dtype=dtype),
            torch.randn(2, num_tokens, text_hidden_size).to(
                device=training_device, dtype=dtype),
            unet_scale,
        ),
        output_path=unet_path,
        ordered_input_names=unet_inputs,
        # has to be different from "sample" for correct tracing
        output_names=["out_sample"],
        dynamic_axes={
            "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            "timestep": {0: "batch"},
            "encoder_hidden_states": {0: "batch", 1: "sequence"},
        },
        opset=opset,
        use_external_data_format=True,  # UNet is > 2GB, so the weights need to be split
    )
    unet_model_path = str(unet_path.absolute().as_posix())
    unet_dir = path.dirname(unet_model_path)
    unet = load(unet_model_path)
    # clean up existing tensor files
    rmtree(unet_dir)
    mkdir(unet_dir)
    # collate external tensor files into one
    save_model(
        unet,
        unet_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location="weights.pb",
        convert_attribute=False,
    )
    del pipeline.unet

    if single_vae:
        logger.debug('VAE config: %s', pipeline.vae.config)

        # SINGLE VAE
        vae_only = pipeline.vae
        vae_latent_channels = vae_only.config.latent_channels
        vae_out_channels = vae_only.config.out_channels
        # forward only through the decoder part
        vae_only.forward = vae_only.decode
        onnx_export(
            vae_only,
            model_args=(
                torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
                    device=training_device, dtype=dtype),
                False,
            ),
            output_path=output_path / "vae" / "model.onnx",
            ordered_input_names=["latent_sample", "return_dict"],
            output_names=["sample"],
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=opset,
        )
    else:
        # VAE ENCODER
        vae_encoder = pipeline.vae
        vae_in_channels = vae_encoder.config.in_channels
        vae_sample_size = vae_encoder.config.sample_size
        # need to get the raw tensor output (sample) from the encoder
        vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
            sample, return_dict)[0].sample()
        onnx_export(
            vae_encoder,
            model_args=(
                torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
                    device=training_device, dtype=dtype),
                False,
            ),
            output_path=output_path / "vae_encoder" / "model.onnx",
            ordered_input_names=["sample", "return_dict"],
            output_names=["latent_sample"],
            dynamic_axes={
                "sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=opset,
        )

        # VAE DECODER
        vae_decoder = pipeline.vae
        vae_latent_channels = vae_decoder.config.latent_channels
        vae_out_channels = vae_decoder.config.out_channels
        # forward only through the decoder part
        # NOTE(review): this reads vae_encoder.decode, not vae_decoder.decode;
        # it works only because both names alias pipeline.vae — confirm intent
        vae_decoder.forward = vae_encoder.decode
        onnx_export(
            vae_decoder,
            model_args=(
                torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
                    device=training_device, dtype=dtype),
                False,
            ),
            output_path=output_path / "vae_decoder" / "model.onnx",
            ordered_input_names=["latent_sample", "return_dict"],
            output_names=["sample"],
            dynamic_axes={
                "latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
            },
            opset=opset,
        )

    del pipeline.vae

    # SAFETY CHECKER
    if pipeline.safety_checker is not None:
        safety_checker = pipeline.safety_checker
        clip_num_channels = safety_checker.config.vision_config.num_channels
        clip_image_size = safety_checker.config.vision_config.image_size
        safety_checker.forward = safety_checker.forward_onnx
        # NOTE(review): vae_sample_size is only assigned in the non-single_vae
        # branch above; a single_vae model with a safety checker would raise
        # NameError here — confirm upscaling models never ship one
        onnx_export(
            pipeline.safety_checker,
            model_args=(
                torch.randn(
                    1,
                    clip_num_channels,
                    clip_image_size,
                    clip_image_size,
                ).to(device=training_device, dtype=dtype),
                torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
                    device=training_device, dtype=dtype),
            ),
            output_path=output_path / "safety_checker" / "model.onnx",
            ordered_input_names=["clip_input", "images"],
            output_names=["out_images", "has_nsfw_concepts"],
            dynamic_axes={
                "clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
                "images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
            },
            opset=opset,
        )
        del pipeline.safety_checker
        safety_checker = OnnxRuntimeModel.from_pretrained(
            output_path / "safety_checker")
        feature_extractor = pipeline.feature_extractor
    else:
        safety_checker = None
        feature_extractor = None

    # reassemble the exported components into an ONNX pipeline
    if single_vae:
        onnx_pipeline = StableDiffusionUpscalePipeline(
            vae=OnnxRuntimeModel.from_pretrained(
                output_path / "vae"),
            text_encoder=OnnxRuntimeModel.from_pretrained(
                output_path / "text_encoder"),
            tokenizer=pipeline.tokenizer,
            low_res_scheduler=pipeline.scheduler,
            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
            scheduler=pipeline.scheduler,
        )
    else:
        onnx_pipeline = OnnxStableDiffusionPipeline(
            vae_encoder=OnnxRuntimeModel.from_pretrained(
                output_path / "vae_encoder"),
            vae_decoder=OnnxRuntimeModel.from_pretrained(
                output_path / "vae_decoder"),
            text_encoder=OnnxRuntimeModel.from_pretrained(
                output_path / "text_encoder"),
            tokenizer=pipeline.tokenizer,
            unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
            scheduler=pipeline.scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=safety_checker is not None,
        )

    logger.info('exporting ONNX model')

    onnx_pipeline.save_pretrained(output_path)
    logger.info("ONNX pipeline saved to %s", output_path)

    del pipeline
    del onnx_pipeline

    # reload the saved pipeline as a smoke test of the conversion
    if single_vae:
        _ = StableDiffusionUpscalePipeline.from_pretrained(
            output_path, provider="CPUExecutionProvider"
        )
    else:
        _ = OnnxStableDiffusionPipeline.from_pretrained(
            output_path, provider="CPUExecutionProvider")

    logger.info("ONNX pipeline is loadable")
|
2023-01-16 15:57:59 +00:00
|
|
|
|
|
|
|
|
2023-01-21 14:52:35 +00:00
|
|
|
def load_models(args, models: "Models"):
    '''
    Convert every model group enabled by the CLI flags.

    args: parsed CLI namespace with diffusion/upscaling/correction flags,
    plus skip, opset, half, and token.
    models: catalog mapping group name to model tuples; groups may be absent
    (extras files are not required to define every group).
    '''
    # .get(key, []) instead of .get(key): a missing group (common in extras
    # JSON files) must be a no-op, not a TypeError from iterating None
    if args.diffusion:
        for source in models.get('diffusion', []):
            if source[0] in args.skip:
                logger.info('Skipping model: %s', source[0])
            else:
                # diffusers upscaling models use the single-VAE export path
                single_vae = 'upscaling' in source[0]
                convert_diffuser(*source, args.opset, args.half, args.token, single_vae=single_vae)

    if args.upscaling:
        for source in models.get('upscaling', []):
            if source[0] in args.skip:
                logger.info('Skipping model: %s', source[0])
            else:
                convert_real_esrgan(*source, args.opset)

    if args.correction:
        for source in models.get('correction', []):
            if source[0] in args.skip:
                logger.info('Skipping model: %s', source[0])
            else:
                convert_gfpgan(*source, args.opset)
|
|
|
|
|
|
|
|
|
2023-01-16 15:57:59 +00:00
|
|
|
def main() -> int:
    '''
    CLI entry point: parse arguments, ensure the model directory exists,
    then convert the recommended base models and any extras files.

    Returns 0 on completion; per-file extras errors are logged, not fatal.
    '''
    parser = ArgumentParser(
        prog='onnx-web model converter',
        description='convert checkpoint models to ONNX')

    # model groups
    parser.add_argument('--correction', action='store_true', default=False)
    parser.add_argument('--diffusion', action='store_true', default=False)
    parser.add_argument('--upscaling', action='store_true', default=False)

    # extra models
    parser.add_argument('--extras', nargs='*', type=str, default=[])
    parser.add_argument('--skip', nargs='*', type=str, default=[])

    # export options
    parser.add_argument(
        '--half',
        action='store_true',
        default=False,
        help='Export models for half precision, faster on some Nvidia cards.'
    )
    parser.add_argument(
        '--opset',
        default=14,
        type=int,
        help="The version of the ONNX operator set to use.",
    )
    parser.add_argument(
        '--token',
        type=str,
        help="HuggingFace token with read permissions for downloading models.",
    )

    args = parser.parse_args()
    logger.info('CLI arguments: %s', args)

    if not path.exists(model_path):
        # fixed message: was 'Model path does not existing'
        logger.info('Model path does not exist, creating: %s', model_path)
        makedirs(model_path)

    logger.info('Converting base models.')
    load_models(args, base_models)

    for file in args.extras:
        logger.info('Loading extra models from %s', file)
        try:
            with open(file, 'r') as f:
                data = loads(f.read())
                logger.info('Converting extra models.')
                load_models(args, data)
        except Exception as err:
            # best-effort: a broken extras file should not abort the rest
            logger.error('Error converting extra models: %s', err)

    return 0
|
|
|
|
|
|
|
|
|
|
|
|
# script entry point: exit with the status code returned by main()
# (exit here is sys.exit, imported at the top of the file)
if __name__ == '__main__':
    exit(main())
|