From ebe813d0351db32af3a74bba61cd4d2b26c60748 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sat, 23 Dec 2023 22:18:35 -0600 Subject: [PATCH] fix(api): change weights filename for new models to match optimum --- api/onnx_web/constants.py | 2 +- api/onnx_web/convert/diffusion/diffusion.py | 2 +- api/onnx_web/convert/diffusion/diffusion_xl.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/onnx_web/constants.py b/api/onnx_web/constants.py index 0eb6c039..6f8b3d43 100644 --- a/api/onnx_web/constants.py +++ b/api/onnx_web/constants.py @@ -1,5 +1,5 @@ ONNX_MODEL = "model.onnx" -ONNX_WEIGHTS = "weights.pb" +ONNX_WEIGHTS = "model.onnx_data" LATENT_FACTOR = 8 LATENT_CHANNELS = 4 diff --git a/api/onnx_web/convert/diffusion/diffusion.py b/api/onnx_web/convert/diffusion/diffusion.py index 71de8003..bdcd90bc 100644 --- a/api/onnx_web/convert/diffusion/diffusion.py +++ b/api/onnx_web/convert/diffusion/diffusion.py @@ -865,7 +865,7 @@ def convert_diffusion_diffusers_optimum( unet_path, save_as_external_data=True, all_tensors_to_one_file=True, - location="weights.pb", + location=ONNX_WEIGHTS, ) return (True, dest_path) diff --git a/api/onnx_web/convert/diffusion/diffusion_xl.py b/api/onnx_web/convert/diffusion/diffusion_xl.py index 92947e92..2ec60f93 100644 --- a/api/onnx_web/convert/diffusion/diffusion_xl.py +++ b/api/onnx_web/convert/diffusion/diffusion_xl.py @@ -9,7 +9,7 @@ from onnx.shape_inference import infer_shapes_path from onnxruntime.transformers.float16 import convert_float_to_float16 from optimum.exporters.onnx import main_export -from ...constants import ONNX_MODEL +from ...constants import ONNX_MODEL, ONNX_WEIGHTS from ..client import fetch_model from ..utils import RESOLVE_FORMATS, ConversionContext, check_ext @@ -112,7 +112,7 @@ def convert_diffusion_diffusers_xl( unet_path, save_as_external_data=True, all_tensors_to_one_file=True, - location="weights.pb", + location=ONNX_WEIGHTS, ) return (True, dest_path)