1
0
Fork 0

reload model from proto file before converting

This commit is contained in:
Sean Sube 2023-02-25 08:12:10 -06:00
parent 2210ee849b
commit 7e65e21410
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 12 additions and 7 deletions

View File

@ -22,7 +22,7 @@ from diffusers import (
OnnxStableDiffusionPipeline, OnnxStableDiffusionPipeline,
StableDiffusionPipeline, StableDiffusionPipeline,
) )
from onnx import load, save_model from onnx import load_model, save_model
from onnxruntime.transformers.float16 import convert_float_to_float16 from onnxruntime.transformers.float16 import convert_float_to_float16
from torch.onnx import export from torch.onnx import export
@ -48,15 +48,13 @@ def onnx_export(
""" """
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
""" """
if half:
model = convert_float_to_float16(model, keep_io_types=True, force_fp16_initializers=True)
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
output_file = output_path.as_posix()
export( export(
model, model,
model_args, model_args,
f=output_path.as_posix(), f=output_file,
input_names=ordered_input_names, input_names=ordered_input_names,
output_names=output_names, output_names=output_names,
dynamic_axes=dynamic_axes, dynamic_axes=dynamic_axes,
@ -64,6 +62,13 @@ def onnx_export(
opset_version=opset, opset_version=opset,
) )
if half:
logger.info("converting model to FP16 internally")
base_model = load_model(output_file)
opt_model = convert_float_to_float16(base_model, keep_io_types=True, force_fp16_initializers=True)
save_model(opt_model, f"{output_file}-optimized")
@torch.no_grad() @torch.no_grad()
def convert_diffusion_diffusers( def convert_diffusion_diffusers(
@ -171,7 +176,7 @@ def convert_diffusion_diffusers(
) )
unet_model_path = str(unet_path.absolute().as_posix()) unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = path.dirname(unet_model_path) unet_dir = path.dirname(unet_model_path)
unet = load(unet_model_path) unet = load_model(unet_model_path)
# clean up existing tensor files # clean up existing tensor files
rmtree(unet_dir) rmtree(unet_dir)
mkdir(unet_dir) mkdir(unet_dir)