1
0
Fork 0

feat(api): add new optimum-based SD converter

This commit is contained in:
Sean Sube 2023-12-23 22:09:57 -06:00
parent 2b8b59a39c
commit b6ef00e437
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
3 changed files with 120 additions and 2 deletions

View File

@ -17,7 +17,10 @@ from .client import add_model_source, fetch_model
from .client.huggingface import HuggingfaceClient from .client.huggingface import HuggingfaceClient
from .correction.gfpgan import convert_correction_gfpgan from .correction.gfpgan import convert_correction_gfpgan
from .diffusion.control import convert_diffusion_control from .diffusion.control import convert_diffusion_control
from .diffusion.diffusion import convert_diffusion_diffusers from .diffusion.diffusion import (
convert_diffusion_diffusers,
convert_diffusion_diffusers_optimum,
)
from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl from .diffusion.diffusion_xl import convert_diffusion_diffusers_xl
from .diffusion.lora import blend_loras from .diffusion.lora import blend_loras
from .diffusion.textual_inversion import blend_textual_inversions from .diffusion.textual_inversion import blend_textual_inversions
@ -60,6 +63,7 @@ model_converters: Dict[str, Any] = {
"img2img-sdxl": convert_diffusion_diffusers_xl, "img2img-sdxl": convert_diffusion_diffusers_xl,
"inpaint": convert_diffusion_diffusers, "inpaint": convert_diffusion_diffusers,
"txt2img": convert_diffusion_diffusers, "txt2img": convert_diffusion_diffusers,
"txt2img-optimum": convert_diffusion_diffusers_optimum,
"txt2img-sdxl": convert_diffusion_diffusers_xl, "txt2img-sdxl": convert_diffusion_diffusers_xl,
} }

View File

@ -25,6 +25,9 @@ from diffusers import (
StableDiffusionUpscalePipeline, StableDiffusionUpscalePipeline,
) )
from onnx import load_model, save_model from onnx import load_model, save_model
from onnx.shape_inference import infer_shapes_path
from onnxruntime.transformers.float16 import convert_float_to_float16
from optimum.exporters.onnx import main_export
from ...constants import ONNX_MODEL, ONNX_WEIGHTS from ...constants import ONNX_MODEL, ONNX_WEIGHTS
from ...diffusers.load import optimize_pipeline from ...diffusers.load import optimize_pipeline
@ -751,3 +754,114 @@ def convert_diffusion_diffusers(
logger.debug("skipping ONNX reload test") logger.debug("skipping ONNX reload test")
return (True, dest_path) return (True, dest_path)
@torch.no_grad()
def convert_diffusion_diffusers_optimum(
conversion: ConversionContext,
model: Dict,
format: Optional[str],
) -> Tuple[bool, str]:
name = str(model.get("name")).strip()
source = model.get("source")
# optional
image_size = model.get("image_size", None)
pipe_type = model.get("pipeline", "txt2img")
replace_vae = model.get("vae", None)
version = model.get("version", None)
device = conversion.training_device
dtype = conversion.torch_dtype()
logger.debug("using Torch dtype %s for pipeline", dtype)
dest_path = path.join(conversion.model_path, name)
model_index = path.join(dest_path, "model_index.json")
model_hash = path.join(dest_path, "hash.txt")
# diffusers go into a directory rather than .onnx file
logger.info(
"converting Stable Diffusion model %s: %s -> %s/", name, source, dest_path
)
if path.exists(dest_path) and path.exists(model_index):
logger.info("ONNX model already exists, skipping conversion")
if "hash" in model and not path.exists(model_hash):
logger.info("ONNX model does not have hash file, adding one")
with open(model_hash, "w") as f:
f.write(model["hash"])
return (False, dest_path)
cache_path = fetch_model(conversion, name, source, format=format)
temp_path = path.join(conversion.cache_path, f"{name}-torch")
pipe_class = CONVERT_PIPELINES.get(pipe_type)
v2, pipe_args = get_model_version(
cache_path, conversion.map_location, size=image_size, version=version
)
if path.isdir(cache_path):
pipeline = pipe_class.from_pretrained(cache_path, **pipe_args)
else:
pipeline = pipe_class.from_single_file(cache_path, **pipe_args)
if replace_vae is not None:
vae_path = path.join(conversion.model_path, replace_vae)
vae_file = check_ext(vae_path, RESOLVE_FORMATS)
if vae_file[0]:
logger.debug("loading VAE from single tensor file: %s", vae_path)
pipeline.vae = AutoencoderKL.from_single_file(vae_path)
else:
logger.debug("loading VAE from single tensor file: %s", vae_path)
pipeline.vae = AutoencoderKL.from_pretrained(replace_vae)
if is_torch_2_0:
pipeline.unet.set_attn_processor(AttnProcessor())
pipeline.vae.set_attn_processor(AttnProcessor())
optimize_pipeline(conversion, pipeline)
if path.exists(temp_path):
logger.debug("torch model already exists for %s: %s", source, temp_path)
else:
logger.debug("exporting torch model for %s: %s", source, temp_path)
pipeline.save_pretrained(temp_path)
main_export(
temp_path,
output=dest_path,
task="stable-diffusion",
device=device,
fp16=conversion.has_optimization(
"torch-fp16"
), # optimum's fp16 mode only works on CUDA or ROCm
framework="pt",
)
if "hash" in model:
logger.debug("adding hash file to ONNX model")
with open(model_hash, "w") as f:
f.write(model["hash"])
if conversion.half:
unet_path = path.join(dest_path, "unet", ONNX_MODEL)
infer_shapes_path(unet_path)
unet = load_model(unet_path)
opt_model = convert_float_to_float16(
unet,
disable_shape_infer=True,
force_fp16_initializers=True,
keep_io_types=True,
op_block_list=["Attention", "MultiHeadAttention"],
)
save_model(
opt_model,
unet_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
)
return (True, dest_path)

View File

@ -115,4 +115,4 @@ def convert_diffusion_diffusers_xl(
location="weights.pb", location="weights.pb",
) )
return False, dest_path return (True, dest_path)