fix(api): change weights filename for new models to match optimum
This commit is contained in:
parent
14208de393
commit
ebe813d035
|
@ -1,5 +1,5 @@
|
||||||
ONNX_MODEL = "model.onnx"
|
ONNX_MODEL = "model.onnx"
|
||||||
ONNX_WEIGHTS = "weights.pb"
|
ONNX_WEIGHTS = "model.onnx_data"
|
||||||
|
|
||||||
LATENT_FACTOR = 8
|
LATENT_FACTOR = 8
|
||||||
LATENT_CHANNELS = 4
|
LATENT_CHANNELS = 4
|
||||||
|
|
|
@ -865,7 +865,7 @@ def convert_diffusion_diffusers_optimum(
|
||||||
unet_path,
|
unet_path,
|
||||||
save_as_external_data=True,
|
save_as_external_data=True,
|
||||||
all_tensors_to_one_file=True,
|
all_tensors_to_one_file=True,
|
||||||
location="weights.pb",
|
location=ONNX_WEIGHTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (True, dest_path)
|
return (True, dest_path)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from onnx.shape_inference import infer_shapes_path
|
||||||
from onnxruntime.transformers.float16 import convert_float_to_float16
|
from onnxruntime.transformers.float16 import convert_float_to_float16
|
||||||
from optimum.exporters.onnx import main_export
|
from optimum.exporters.onnx import main_export
|
||||||
|
|
||||||
from ...constants import ONNX_MODEL
|
from ...constants import ONNX_MODEL, ONNX_WEIGHTS
|
||||||
from ..client import fetch_model
|
from ..client import fetch_model
|
||||||
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ def convert_diffusion_diffusers_xl(
|
||||||
unet_path,
|
unet_path,
|
||||||
save_as_external_data=True,
|
save_as_external_data=True,
|
||||||
all_tensors_to_one_file=True,
|
all_tensors_to_one_file=True,
|
||||||
location="weights.pb",
|
location=ONNX_WEIGHTS,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (True, dest_path)
|
return (True, dest_path)
|
||||||
|
|
Loading…
Reference in New Issue