add custom VAE and fp16 support to SDXL conversion
This commit is contained in:
parent
eeebdfebcb
commit
cd06f9291b
|
@ -2,10 +2,14 @@ from logging import getLogger
|
|||
from os import path
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers import AutoencoderKL, StableDiffusionXLPipeline
|
||||
from onnx.shape_inference import infer_shapes_path
|
||||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||
from optimum.exporters.onnx import main_export
|
||||
|
||||
from ...constants import ONNX_MODEL
|
||||
from ..utils import ConversionContext
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -23,7 +27,7 @@ def convert_diffusion_diffusers_xl(
|
|||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||
"""
|
||||
name = model.get("name")
|
||||
# TODO: support alternate VAE
|
||||
replace_vae = model.get("vae", None)
|
||||
|
||||
device = conversion.training_device
|
||||
dtype = conversion.torch_dtype()
|
||||
|
@ -57,6 +61,12 @@ def convert_diffusion_diffusers_xl(
|
|||
else:
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(source)
|
||||
|
||||
if replace_vae is not None:
|
||||
if replace_vae.endswith(".safetensors"):
|
||||
pipeline.vae = AutoencoderKL.from_single_file(replace_vae)
|
||||
else:
|
||||
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
|
||||
|
||||
pipeline.save_pretrained(temp_path)
|
||||
|
||||
# directory -> onnx using optimum exporters
|
||||
|
@ -69,6 +79,23 @@ def convert_diffusion_diffusers_xl(
|
|||
framework="pt",
|
||||
)
|
||||
|
||||
# TODO: optimize UNet to fp16
|
||||
if conversion.half:
|
||||
unet_path = path.join(dest_path, "unet", ONNX_MODEL)
|
||||
infer_shapes_path(unet_path)
|
||||
unet = onnx.load(unet_path)
|
||||
opt_model = convert_float_to_float16(
|
||||
unet,
|
||||
disable_shape_infer=True,
|
||||
force_fp16_initializers=True,
|
||||
keep_io_types=True,
|
||||
op_block_list=["Attention", "MultiHeadAttention"],
|
||||
)
|
||||
onnx.save_model(
|
||||
opt_model,
|
||||
unet_path,
|
||||
save_as_external_data=True,
|
||||
all_tensors_to_one_file=True,
|
||||
location="weights.pb",
|
||||
)
|
||||
|
||||
return False, dest_path
|
||||
|
|
|
@ -24,7 +24,7 @@ class SizeChart(IntEnum):
|
|||
hd16k = 2**14
|
||||
hd32k = 2**15
|
||||
hd64k = 2**16
|
||||
unlimited = 2**32 # sort of
|
||||
unlimited = 2**32 # sort of
|
||||
|
||||
|
||||
class TileOrder:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from logging import getLogger
|
||||
from typing import Any, Dict, Tuple
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
from flask import request
|
||||
|
|
Loading…
Reference in New Issue