From cd06f9291b3f7ada56e6745df87a5aae26423338 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Mon, 11 Sep 2023 18:18:38 -0500 Subject: [PATCH] add custom VAE and fp16 support to SDXL conversion --- .../convert/diffusion/diffusion_xl.py | 33 +++++++++++++++++-- api/onnx_web/params.py | 2 +- api/onnx_web/server/params.py | 2 +- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index a7dcf104..6982d0b1 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -2,10 +2,14 @@ from logging import getLogger from os import path from typing import Dict, Optional, Tuple +import onnx import torch -from diffusers import StableDiffusionXLPipeline +from diffusers import AutoencoderKL, StableDiffusionXLPipeline +from onnx.shape_inference import infer_shapes_path +from onnxruntime.transformers.float16 import convert_float_to_float16 from optimum.exporters.onnx import main_export +from ...constants import ONNX_MODEL from ..utils import ConversionContext logger = getLogger(__name__) @@ -23,7 +27,7 @@ def convert_diffusion_diffusers_xl( From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py """ name = model.get("name") - # TODO: support alternate VAE + replace_vae = model.get("vae", None) device = conversion.training_device dtype = conversion.torch_dtype() @@ -57,6 +61,12 @@ def convert_diffusion_diffusers_xl( else: pipeline = StableDiffusionXLPipeline.from_pretrained(source) + if replace_vae is not None: + if replace_vae.endswith(".safetensors"): + pipeline.vae = AutoencoderKL.from_single_file(replace_vae) + else: + pipeline.vae = AutoencoderKL.from_pretrained(replace_vae) + pipeline.save_pretrained(temp_path) # directory -> onnx using optimum exporters @@ -69,6 +79,23 @@ def convert_diffusion_diffusers_xl( framework="pt", ) - # TODO: optimize UNet to fp16 + if conversion.half: + unet_path = path.join(dest_path, "unet", ONNX_MODEL) + infer_shapes_path(unet_path) + unet = onnx.load(unet_path) + opt_model = convert_float_to_float16( + unet, + disable_shape_infer=True, + force_fp16_initializers=True, + keep_io_types=True, + op_block_list=["Attention", "MultiHeadAttention"], + ) + onnx.save_model( + opt_model, + unet_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.pb", + ) return False, dest_path diff --git a/api/onnx_web/params.py b/api/onnx_web/params.py index a6f2c888..5b504c6a 100644 --- a/api/onnx_web/params.py +++ b/api/onnx_web/params.py @@ -24,7 +24,7 @@ class SizeChart(IntEnum): hd16k = 2**14 hd32k = 2**15 hd64k = 2**16 - unlimited = 2**32 # sort of + unlimited = 2**32 # sort of class TileOrder: diff --git a/api/onnx_web/server/params.py b/api/onnx_web/server/params.py index e37a15a8..c32fdbf2 100644 --- a/api/onnx_web/server/params.py +++ b/api/onnx_web/server/params.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Any, Dict, Tuple +from typing import Dict, Tuple import numpy as np from flask import request