1
0
Fork 0

fix(api): limit fp16 ops for v2.1 models (#364)

This commit is contained in:
Sean Sube 2023-04-30 08:49:34 -05:00
parent ea9d15fdd2
commit 572a5159ad
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
2 changed files with 12 additions and 2 deletions

View File

@ -110,6 +110,7 @@ def convert_diffusion_diffusers_cnet(
num_tokens,
text_hidden_size,
unet: Optional[Any] = None,
v2: Optional[bool] = False,
):
# CNet
if unet is not None:
@ -218,6 +219,7 @@ def convert_diffusion_diffusers_cnet(
opset=conversion.opset,
half=conversion.half,
external_data=True, # UNet is > 2GB, so the weights need to be split
v2=v2,
)
cnet_model_path = str(cnet_path.absolute().as_posix())
cnet_dir = path.dirname(cnet_model_path)
@ -289,7 +291,7 @@ def convert_diffusion_diffusers(
return (False, dest_path)
pipe_class = available_pipelines.get(pipe_type)
_v2, pipe_args = get_model_version(
v2, pipe_args = get_model_version(
source, conversion.map_location, size=image_size, version=version
)
@ -397,6 +399,7 @@ def convert_diffusion_diffusers(
opset=conversion.opset,
half=conversion.half,
external_data=True,
v2=v2,
)
unet_model_path = str(unet_path.absolute().as_posix())
unet_dir = path.dirname(unet_model_path)
@ -429,6 +432,7 @@ def convert_diffusion_diffusers(
num_tokens,
text_hidden_size,
unet=pipeline.unet,
v2=v2,
)
else:
logger.debug("skipping CNet for single-VAE model")

View File

@ -287,6 +287,7 @@ def onnx_export(
opset,
half=False,
external_data=False,
v2=False,
):
"""
From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@ -305,6 +306,10 @@ def onnx_export(
opset_version=opset,
)
op_block_list=None
if v2:
op_block_list=["Attention", "MultiHeadAttention"]
if half:
logger.info("converting model to fp16 internally: %s", output_file)
infer_shapes_path(output_file)
@ -312,8 +317,9 @@ def onnx_export(
opt_model = convert_float_to_float16(
base_model,
disable_shape_infer=True,
keep_io_types=True,
force_fp16_initializers=True,
keep_io_types=True,
op_block_list=op_block_list,
)
save_model(
opt_model,