feat(api): add support for Stable Diffusion models to conversion script
This commit is contained in:
parent
4d0898a52c
commit
decb2813c6
|
@ -1,15 +1,25 @@
|
|||
from argparse import ArgumentParser
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from os import path, environ
|
||||
from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline, StableDiffusionPipeline
|
||||
from onnx import load, save_model
|
||||
from os import mkdir, path, environ
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from sys import exit
|
||||
from torch.onnx import export
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
sources = {
|
||||
sources: Dict[str, List[Tuple[str, str]]] = {
|
||||
'diffusers': [
|
||||
# TODO: add stable diffusion models
|
||||
# v1.x
|
||||
('stable-diffusion-onnx-v1-5', 'runwayml/stable-diffusion-v1-5'),
|
||||
('stable-diffusion-onnx-v1-inpainting', 'runwayml/stable-diffusion-inpainting'),
|
||||
# v2.x
|
||||
('stable-diffusion-onnx-v2-1', 'stabilityai/stable-diffusion-2-1'),
|
||||
('stable-diffusion-onnx-v2-inpainting', 'stabilityai/stable-diffusion-2-inpainting'),
|
||||
],
|
||||
'gfpgan': [
|
||||
('GFPGANv1.3', 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'),
|
||||
|
@ -23,14 +33,17 @@ model_path = environ.get('ONNX_WEB_MODEL_PATH',
|
|||
path.join('..', 'models'))
|
||||
|
||||
|
||||
training_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_real_esrgan(name: str, url: str):
|
||||
def convert_real_esrgan(name: str, url: str, opset: int):
|
||||
dest_path = path.join(model_path, name)
|
||||
dest_onnx = path.join(model_path, name + '.onnx')
|
||||
print('converting Real ESRGAN into %s' % dest_path)
|
||||
print('converting Real ESRGAN model: %s -> %s' % (name, dest_path))
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
print('Real ESRGAN ONNX model already exists, skipping.')
|
||||
print('ONNX model already exists, skipping.')
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
|
@ -38,11 +51,11 @@ def convert_real_esrgan(name: str, url: str):
|
|||
dest_path = load_file_from_url(
|
||||
url=url, model_dir=path.join(dest_path, name), progress=True, file_name=None)
|
||||
|
||||
print('loading and training Real ESRGAN model')
|
||||
print('loading and training model')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=4)
|
||||
model.load_state_dict(torch.load(dest_path)['params_ema'])
|
||||
model.train(False)
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64)
|
||||
|
@ -51,7 +64,7 @@ def convert_real_esrgan(name: str, url: str):
|
|||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||
'output': {2: 'width', 3: 'height'}}
|
||||
|
||||
print('exporting Real ESRGAN model to %s' % dest_onnx)
|
||||
print('exporting ONNX model to %s' % dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
|
@ -59,34 +72,34 @@ def convert_real_esrgan(name: str, url: str):
|
|||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=11,
|
||||
opset_version=opset,
|
||||
export_params=True
|
||||
)
|
||||
print('Real ESRGAN exported to ONNX')
|
||||
print('Real ESRGAN exported to ONNX successfully.')
|
||||
|
||||
|
||||
def convert_gfpgan(name: str, url: str):
|
||||
@torch.no_grad()
|
||||
def convert_gfpgan(name: str, url: str, opset: int):
|
||||
dest_path = path.join(model_path, name)
|
||||
dest_onnx = path.join(model_path, name + '.onnx')
|
||||
|
||||
print('converting GFPGAN into %s' % dest_path)
|
||||
print('converting GFPGAN model: %s -> %s' % (name, dest_path))
|
||||
|
||||
if path.isfile(dest_onnx):
|
||||
print('GFPGAN ONNX model already exists, skipping.')
|
||||
print('ONNX model already exists, skipping.')
|
||||
return
|
||||
|
||||
if not path.isfile(dest_path):
|
||||
print('existing model not found, downloading...')
|
||||
print('PTH model not found, downloading...')
|
||||
dest_path = load_file_from_url(
|
||||
url=url, model_dir=path.join(dest_path, name), progress=True, file_name=None)
|
||||
|
||||
print('loading and training GFPGAN model')
|
||||
print('loading and training model')
|
||||
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64,
|
||||
num_block=23, num_grow_ch=32, scale=4)
|
||||
|
||||
# TODO: make sure strict=False is safe here
|
||||
model.load_state_dict(torch.load(dest_path)['params_ema'], strict=False)
|
||||
model.train(False)
|
||||
model.to(training_device).train(False)
|
||||
model.eval()
|
||||
|
||||
rng = torch.rand(1, 3, 64, 64)
|
||||
|
@ -95,7 +108,7 @@ def convert_gfpgan(name: str, url: str):
|
|||
dynamic_axes = {'data': {2: 'width', 3: 'height'},
|
||||
'output': {2: 'width', 3: 'height'}}
|
||||
|
||||
print('exporting GFPGAN model to %s' % dest_onnx)
|
||||
print('exporting ONNX model to %s' % dest_onnx)
|
||||
export(
|
||||
model,
|
||||
rng,
|
||||
|
@ -103,13 +116,233 @@ def convert_gfpgan(name: str, url: str):
|
|||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=11,
|
||||
opset_version=opset,
|
||||
export_params=True
|
||||
)
|
||||
print('GFPGAN exported to ONNX')
|
||||
print('GFPGAN exported to ONNX successfully.')
|
||||
|
||||
|
||||
def convert_diffuser():
|
||||
def onnx_export(
|
||||
model,
|
||||
model_args: tuple,
|
||||
output_path: Path,
|
||||
ordered_input_names,
|
||||
output_names,
|
||||
dynamic_axes,
|
||||
opset,
|
||||
use_external_data_format=False,
|
||||
):
|
||||
'''
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
'''
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
export(
|
||||
model,
|
||||
model_args,
|
||||
f=output_path.as_posix(),
|
||||
input_names=ordered_input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
do_constant_folding=True,
|
||||
opset_version=opset,
|
||||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_diffuser(name: str, url: str, opset: int, half: bool):
|
||||
'''
|
||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
'''
|
||||
dtype = torch.float16 if half else torch.float32
|
||||
dest_path = path.join(model_path, name)
|
||||
print('converting Diffusers model: %s -> %s' % (name, dest_path))
|
||||
|
||||
if path.isdir(dest_path):
|
||||
print('ONNX model already exists, skipping.')
|
||||
return
|
||||
|
||||
if half and training_device != 'cuda':
|
||||
raise ValueError(
|
||||
'Half precision model export is only supported on GPUs with CUDA')
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
url, torch_dtype=dtype).to(training_device)
|
||||
output_path = Path(dest_path)
|
||||
|
||||
# TEXT ENCODER
|
||||
num_tokens = pipeline.text_encoder.config.max_position_embeddings
|
||||
text_hidden_size = pipeline.text_encoder.config.hidden_size
|
||||
text_input = pipeline.tokenizer(
|
||||
"A sample prompt",
|
||||
padding="max_length",
|
||||
max_length=pipeline.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
onnx_export(
|
||||
pipeline.text_encoder,
|
||||
# casting to torch.int32 until the CLIP fix is released: https://github.com/huggingface/transformers/pull/18515/files
|
||||
model_args=(text_input.input_ids.to(
|
||||
device=training_device, dtype=torch.int32)),
|
||||
output_path=output_path / "text_encoder" / "model.onnx",
|
||||
ordered_input_names=["input_ids"],
|
||||
output_names=["last_hidden_state", "pooler_output"],
|
||||
dynamic_axes={
|
||||
"input_ids": {0: "batch", 1: "sequence"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
del pipeline.text_encoder
|
||||
|
||||
# UNET
|
||||
unet_in_channels = pipeline.unet.config.in_channels
|
||||
unet_sample_size = pipeline.unet.config.sample_size
|
||||
unet_path = output_path / "unet" / "model.onnx"
|
||||
onnx_export(
|
||||
pipeline.unet,
|
||||
model_args=(
|
||||
torch.randn(2, unet_in_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
torch.randn(2).to(device=training_device, dtype=dtype),
|
||||
torch.randn(2, num_tokens, text_hidden_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=unet_path,
|
||||
ordered_input_names=["sample", "timestep",
|
||||
"encoder_hidden_states", "return_dict"],
|
||||
# has to be different from "sample" for correct tracing
|
||||
output_names=["out_sample"],
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"timestep": {0: "batch"},
|
||||
"encoder_hidden_states": {0: "batch", 1: "sequence"},
|
||||
},
|
||||
opset=opset,
|
||||
use_external_data_format=True, # UNet is > 2GB, so the weights need to be split
|
||||
)
|
||||
unet_model_path = str(unet_path.absolute().as_posix())
|
||||
unet_dir = path.dirname(unet_model_path)
|
||||
unet = load(unet_model_path)
|
||||
# clean up existing tensor files
|
||||
rmtree(unet_dir)
|
||||
mkdir(unet_dir)
|
||||
# collate external tensor files into one
|
||||
save_model(
|
||||
unet,
|
||||
unet_model_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location="weights.pb",
|
||||
convert_attribute=False,
|
||||
)
|
||||
del pipeline.unet
|
||||
|
||||
# VAE ENCODER
|
||||
vae_encoder = pipeline.vae
|
||||
vae_in_channels = vae_encoder.config.in_channels
|
||||
vae_sample_size = vae_encoder.config.sample_size
|
||||
# need to get the raw tensor output (sample) from the encoder
|
||||
vae_encoder.forward = lambda sample, return_dict: vae_encoder.encode(
|
||||
sample, return_dict)[0].sample()
|
||||
onnx_export(
|
||||
vae_encoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_in_channels, vae_sample_size, vae_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_encoder" / "model.onnx",
|
||||
ordered_input_names=["sample", "return_dict"],
|
||||
output_names=["latent_sample"],
|
||||
dynamic_axes={
|
||||
"sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
|
||||
# VAE DECODER
|
||||
vae_decoder = pipeline.vae
|
||||
vae_latent_channels = vae_decoder.config.latent_channels
|
||||
vae_out_channels = vae_decoder.config.out_channels
|
||||
# forward only through the decoder part
|
||||
vae_decoder.forward = vae_encoder.decode
|
||||
onnx_export(
|
||||
vae_decoder,
|
||||
model_args=(
|
||||
torch.randn(1, vae_latent_channels, unet_sample_size, unet_sample_size).to(
|
||||
device=training_device, dtype=dtype),
|
||||
False,
|
||||
),
|
||||
output_path=output_path / "vae_decoder" / "model.onnx",
|
||||
ordered_input_names=["latent_sample", "return_dict"],
|
||||
output_names=["sample"],
|
||||
dynamic_axes={
|
||||
"latent_sample": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
del pipeline.vae
|
||||
|
||||
# SAFETY CHECKER
|
||||
if pipeline.safety_checker is not None:
|
||||
safety_checker = pipeline.safety_checker
|
||||
clip_num_channels = safety_checker.config.vision_config.num_channels
|
||||
clip_image_size = safety_checker.config.vision_config.image_size
|
||||
safety_checker.forward = safety_checker.forward_onnx
|
||||
onnx_export(
|
||||
pipeline.safety_checker,
|
||||
model_args=(
|
||||
torch.randn(
|
||||
1,
|
||||
clip_num_channels,
|
||||
clip_image_size,
|
||||
clip_image_size,
|
||||
).to(device=training_device, dtype=dtype),
|
||||
torch.randn(1, vae_sample_size, vae_sample_size, vae_out_channels).to(
|
||||
device=training_device, dtype=dtype),
|
||||
),
|
||||
output_path=output_path / "safety_checker" / "model.onnx",
|
||||
ordered_input_names=["clip_input", "images"],
|
||||
output_names=["out_images", "has_nsfw_concepts"],
|
||||
dynamic_axes={
|
||||
"clip_input": {0: "batch", 1: "channels", 2: "height", 3: "width"},
|
||||
"images": {0: "batch", 1: "height", 2: "width", 3: "channels"},
|
||||
},
|
||||
opset=opset,
|
||||
)
|
||||
del pipeline.safety_checker
|
||||
safety_checker = OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "safety_checker")
|
||||
feature_extractor = pipeline.feature_extractor
|
||||
else:
|
||||
safety_checker = None
|
||||
feature_extractor = None
|
||||
|
||||
onnx_pipeline = OnnxStableDiffusionPipeline(
|
||||
vae_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_encoder"),
|
||||
vae_decoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "vae_decoder"),
|
||||
text_encoder=OnnxRuntimeModel.from_pretrained(
|
||||
output_path / "text_encoder"),
|
||||
tokenizer=pipeline.tokenizer,
|
||||
unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
|
||||
scheduler=pipeline.scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
requires_safety_checker=safety_checker is not None,
|
||||
)
|
||||
|
||||
onnx_pipeline.save_pretrained(output_path)
|
||||
print("ONNX pipeline saved to", output_path)
|
||||
|
||||
del pipeline
|
||||
del onnx_pipeline
|
||||
_ = OnnxStableDiffusionPipeline.from_pretrained(
|
||||
output_path, provider="CPUExecutionProvider")
|
||||
print("ONNX pipeline is loadable")
|
||||
pass
|
||||
|
||||
|
||||
|
@ -117,27 +350,36 @@ def main() -> int:
|
|||
parser = ArgumentParser(
|
||||
prog='onnx-web model converter',
|
||||
description='convert checkpoint models to ONNX')
|
||||
parser.add_argument('--diffusers', type=str, nargs='*',
|
||||
help='models using the diffusers pipeline')
|
||||
parser.add_argument('--gfpgan', action='store_true')
|
||||
parser.add_argument('--resrgan', action='store_true')
|
||||
parser.add_argument('--diffusers', action='store_true', default=True)
|
||||
parser.add_argument('--gfpgan', action='store_true', default=False)
|
||||
parser.add_argument('--resrgan', action='store_true', default=False)
|
||||
parser.add_argument(
|
||||
'--opset',
|
||||
default=14,
|
||||
type=int,
|
||||
help="The version of the ONNX operator set to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
'--half',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Export models for half precision, faster on some Nvidia cards'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
if args.diffusers:
|
||||
for source in args.diffusers:
|
||||
print('converting Diffusers model: %s' % source[0])
|
||||
convert_diffuser(*source)
|
||||
for source in sources.get('diffusers'):
|
||||
convert_diffuser(*source, args.opset, args.half)
|
||||
|
||||
if args.resrgan:
|
||||
for source in sources.get('real_esrgan'):
|
||||
print('converting Real ESRGAN model: %s' % source[0])
|
||||
convert_real_esrgan(*source)
|
||||
convert_real_esrgan(*source, args.opset)
|
||||
|
||||
if args.gfpgan:
|
||||
for source in sources.get('gfpgan'):
|
||||
print('converting GFPGAN model: %s' % source[0])
|
||||
convert_gfpgan(*source)
|
||||
convert_gfpgan(*source, args.opset)
|
||||
|
||||
return 0
|
||||
|
||||
|
|
|
@ -167,8 +167,9 @@ def upscale_gfpgan(ctx: ServerContext, params: UpscaleParams, image, upsampler=N
|
|||
if upsampler is None:
|
||||
upsampler = make_resrgan(ctx, params, tile=512)
|
||||
|
||||
face_path = path.join(ctx.model_path, '%s.pth' % (params.face_model)) # TODO: convert to ONNX
|
||||
face_path = path.join(ctx.model_path, '%s.pth' % (params.face_model))
|
||||
|
||||
# TODO: doesn't have a model param, not sure how to pass ONNX model
|
||||
face_enhancer = GFPGANer(
|
||||
model_path=face_path,
|
||||
upscale=params.outscale,
|
||||
|
|
Loading…
Reference in New Issue