
add custom VAE and fp16 support to SDXL conversion

Sean Sube 2023-09-11 18:18:38 -05:00
parent eeebdfebcb
commit cd06f9291b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 32 additions and 5 deletions

View File

@@ -2,10 +2,14 @@ from logging import getLogger
 from os import path
 from typing import Dict, Optional, Tuple
 
+import onnx
 import torch
-from diffusers import StableDiffusionXLPipeline
+from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+from onnx.shape_inference import infer_shapes_path
+from onnxruntime.transformers.float16 import convert_float_to_float16
 from optimum.exporters.onnx import main_export
 
+from ...constants import ONNX_MODEL
 from ..utils import ConversionContext
 
 logger = getLogger(__name__)
@@ -23,7 +27,7 @@ def convert_diffusion_diffusers_xl(
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
     """
     name = model.get("name")
-    # TODO: support alternate VAE
+    replace_vae = model.get("vae", None)
 
     device = conversion.training_device
     dtype = conversion.torch_dtype()
@@ -57,6 +61,12 @@ def convert_diffusion_diffusers_xl(
     else:
         pipeline = StableDiffusionXLPipeline.from_pretrained(source)
 
+    if replace_vae is not None:
+        if replace_vae.endswith(".safetensors"):
+            pipeline.vae = AutoencoderKL.from_single_file(replace_vae)
+        else:
+            pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
+
     pipeline.save_pretrained(temp_path)
 
     # directory -> onnx using optimum exporters
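
For reference, a minimal standalone sketch of the same VAE swap outside the converter; the source model, VAE path, and output directory below are placeholder assumptions, not values taken from the project.

# Standalone sketch of replacing the VAE on an SDXL pipeline before export.
# The model id and paths here are assumed placeholders.
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

source = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed source model
replace_vae = "./models/custom-vae.safetensors"      # or a diffusers repo id

pipeline = StableDiffusionXLPipeline.from_pretrained(source)

if replace_vae.endswith(".safetensors"):
    # single-file checkpoints load directly from disk
    pipeline.vae = AutoencoderKL.from_single_file(replace_vae)
else:
    # anything else is treated as a diffusers-format directory or hub repo id
    pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)

# save the patched pipeline so it can be exported to ONNX afterwards
pipeline.save_pretrained("./tmp/sdxl-custom-vae")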
@@ -69,6 +79,23 @@ def convert_diffusion_diffusers_xl(
         framework="pt",
     )
 
-    # TODO: optimize UNet to fp16
+    if conversion.half:
+        unet_path = path.join(dest_path, "unet", ONNX_MODEL)
+        infer_shapes_path(unet_path)
+        unet = onnx.load(unet_path)
+        opt_model = convert_float_to_float16(
+            unet,
+            disable_shape_infer=True,
+            force_fp16_initializers=True,
+            keep_io_types=True,
+            op_block_list=["Attention", "MultiHeadAttention"],
+        )
+        onnx.save_model(
+            opt_model,
+            unet_path,
+            save_as_external_data=True,
+            all_tensors_to_one_file=True,
+            location="weights.pb",
+        )
 
     return False, dest_path
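
As a usage note, the fp16 pass above can also be run as a standalone post-processing step on an already exported UNet; the input path in this sketch is an assumption, while the conversion arguments mirror the ones added in this commit.

# Standalone sketch of the fp16 post-processing step for an exported ONNX UNet.
# The unet_path below is an assumed location, not the project's layout.
from os import path

import onnx
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16

unet_path = path.join("models", "sdxl-base", "unet", "model.onnx")  # assumed

# write inferred shapes back to the model file first, since shape inference
# is disabled during the fp16 conversion below
infer_shapes_path(unet_path)
unet = onnx.load(unet_path)

# convert initializers to fp16 while keeping graph inputs/outputs in fp32 and
# leaving the attention ops in fp32, since they are prone to fp16 overflow
opt_model = convert_float_to_float16(
    unet,
    disable_shape_infer=True,
    force_fp16_initializers=True,
    keep_io_types=True,
    op_block_list=["Attention", "MultiHeadAttention"],
)

# the SDXL UNet can exceed the 2GB protobuf limit, so keep the tensors in a
# single external weights file next to the graph
onnx.save_model(
    opt_model,
    unet_path,
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",
)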

View File

@@ -1,5 +1,5 @@
 from logging import getLogger
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
 
 import numpy as np
 from flask import request