fix(api): limit fp16 ops for v2.1 models (#364)
parent ea9d15fdd2
commit 572a5159ad
@@ -110,6 +110,7 @@ def convert_diffusion_diffusers_cnet(
     num_tokens,
     text_hidden_size,
     unet: Optional[Any] = None,
+    v2: Optional[bool] = False,
 ):
     # CNet
     if unet is not None:
@@ -218,6 +219,7 @@ def convert_diffusion_diffusers_cnet(
         opset=conversion.opset,
         half=conversion.half,
         external_data=True,  # UNet is > 2GB, so the weights need to be split
+        v2=v2,
     )
     cnet_model_path = str(cnet_path.absolute().as_posix())
     cnet_dir = path.dirname(cnet_model_path)
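The external_data=True comment above refers to ONNX's storage format: models are protocol buffers, which are capped at 2 GB, so a UNet larger than that must keep its weights in a sidecar file. A minimal sketch of that split-save pattern with stock onnx APIs (paths and file names here are illustrative, not this project's layout):

import onnx

# Load a model whose weights exceed the 2 GB protobuf limit.
model = onnx.load("unet/model.onnx")

# Re-save it with every tensor moved into a single sidecar file.
onnx.save_model(
    model,
    "unet/model.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",  # resolved relative to the model file
)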
@@ -289,7 +291,7 @@ def convert_diffusion_diffusers(
         return (False, dest_path)
 
     pipe_class = available_pipelines.get(pipe_type)
-    _v2, pipe_args = get_model_version(
+    v2, pipe_args = get_model_version(
         source, conversion.map_location, size=image_size, version=version
     )
 
@@ -397,6 +399,7 @@ def convert_diffusion_diffusers(
         opset=conversion.opset,
         half=conversion.half,
         external_data=True,
+        v2=v2,
     )
     unet_model_path = str(unet_path.absolute().as_posix())
     unet_dir = path.dirname(unet_model_path)
@@ -429,6 +432,7 @@ def convert_diffusion_diffusers(
             num_tokens,
             text_hidden_size,
             unet=pipeline.unet,
+            v2=v2,
         )
     else:
         logger.debug("skipping CNet for single-VAE model")

@@ -287,6 +287,7 @@ def onnx_export(
     opset,
     half=False,
     external_data=False,
+    v2=False,
 ):
     """
     From https://github.com/huggingface/diffusers/blob/main/scripts/convert_stable_diffusion_checkpoint_to_onnx.py
@@ -305,6 +306,10 @@ def onnx_export(
         opset_version=opset,
     )
 
+    op_block_list = None
+    if v2:
+        op_block_list = ["Attention", "MultiHeadAttention"]
+
     if half:
         logger.info("converting model to fp16 internally: %s", output_file)
         infer_shapes_path(output_file)
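The op_block_list added here is the core of the fix: any op type named in the list is left in fp32 when convert_float_to_float16 rewrites the graph. The motivation for blocking the attention ops on v2.x models is that their attention activations can exceed fp16's finite maximum (~65504), which corrupts outputs. A standalone sketch of the same technique, outside this codebase (the model path is illustrative):

from onnx import load_model, save_model
from onnxconverter_common.float16 import convert_float_to_float16

model = load_model("unet/model.onnx")

# Convert weights and ops to fp16, but keep the (contrib) attention ops
# in fp32 so their activations cannot overflow the fp16 range.
fp16_model = convert_float_to_float16(
    model,
    keep_io_types=True,  # graph inputs/outputs remain fp32 for callers
    op_block_list=["Attention", "MultiHeadAttention"],
)
save_model(fp16_model, "unet/model_fp16.onnx")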
@@ -312,8 +317,9 @@ def onnx_export(
         opt_model = convert_float_to_float16(
             base_model,
             disable_shape_infer=True,
-            keep_io_types=True,
             force_fp16_initializers=True,
+            keep_io_types=True,
+            op_block_list=op_block_list,
         )
         save_model(
             opt_model,
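Because keep_io_types=True is retained, the converted UNet still exposes fp32 inputs and outputs, so callers do not need to change dtypes. A hypothetical smoke test with onnxruntime; the input names follow the diffusers ONNX UNet convention, while the shapes and the timestep dtype are assumptions:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("unet/model.onnx", providers=["CPUExecutionProvider"])

# fp32 feeds still work against the fp16 graph thanks to keep_io_types=True.
out = sess.run(None, {
    "sample": np.random.randn(1, 4, 64, 64).astype(np.float32),
    "timestep": np.array([10.0], dtype=np.float32),  # dtype is an assumption
    "encoder_hidden_states": np.random.randn(1, 77, 1024).astype(np.float32),
})
print(out[0].shape)  # expected: (1, 4, 64, 64)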