reload model from proto file before converting
This commit is contained in:
parent
2210ee849b
commit
7e65e21410
|
@ -22,7 +22,7 @@ from diffusers import (
|
||||||
OnnxStableDiffusionPipeline,
|
OnnxStableDiffusionPipeline,
|
||||||
StableDiffusionPipeline,
|
StableDiffusionPipeline,
|
||||||
)
|
)
|
||||||
from onnx import load, save_model
|
from onnx import load_model, save_model
|
||||||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||||
from torch.onnx import export
|
from torch.onnx import export
|
||||||
|
|
||||||
|
@ -48,15 +48,13 @@ def onnx_export(
|
||||||
"""
|
"""
|
||||||
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if half:
|
|
||||||
model = convert_float_to_float16(model, keep_io_types=True, force_fp16_initializers=True)
|
|
||||||
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
output_file = output_path.as_posix()
|
||||||
|
|
||||||
export(
|
export(
|
||||||
model,
|
model,
|
||||||
model_args,
|
model_args,
|
||||||
f=output_path.as_posix(),
|
f=output_file,
|
||||||
input_names=ordered_input_names,
|
input_names=ordered_input_names,
|
||||||
output_names=output_names,
|
output_names=output_names,
|
||||||
dynamic_axes=dynamic_axes,
|
dynamic_axes=dynamic_axes,
|
||||||
|
@ -64,6 +62,13 @@ def onnx_export(
|
||||||
opset_version=opset,
|
opset_version=opset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if half:
|
||||||
|
logger.info("converting model to FP16 internally")
|
||||||
|
base_model = load_model(output_file)
|
||||||
|
opt_model = convert_float_to_float16(base_model, keep_io_types=True, force_fp16_initializers=True)
|
||||||
|
save_model(opt_model, f"{output_file}-optimized")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert_diffusion_diffusers(
|
def convert_diffusion_diffusers(
|
||||||
|
@ -171,7 +176,7 @@ def convert_diffusion_diffusers(
|
||||||
)
|
)
|
||||||
unet_model_path = str(unet_path.absolute().as_posix())
|
unet_model_path = str(unet_path.absolute().as_posix())
|
||||||
unet_dir = path.dirname(unet_model_path)
|
unet_dir = path.dirname(unet_model_path)
|
||||||
unet = load(unet_model_path)
|
unet = load_model(unet_model_path)
|
||||||
# clean up existing tensor files
|
# clean up existing tensor files
|
||||||
rmtree(unet_dir)
|
rmtree(unet_dir)
|
||||||
mkdir(unet_dir)
|
mkdir(unet_dir)
|
||||||
|
|
Loading…
Reference in New Issue