serialize models before loading
This commit is contained in:
parent
a3a04fd1f4
commit
e1c2ae5b1b
|
@ -383,7 +383,7 @@ def load_text_encoders(
|
||||||
|
|
||||||
# session for te1
|
# session for te1
|
||||||
text_encoder_session = InferenceSession(
|
text_encoder_session = InferenceSession(
|
||||||
text_encoder,
|
text_encoder.SerializeToString(),
|
||||||
providers=[device.ort_provider("text-encoder")],
|
providers=[device.ort_provider("text-encoder")],
|
||||||
sess_options=text_encoder_opts,
|
sess_options=text_encoder_opts,
|
||||||
)
|
)
|
||||||
|
@ -392,7 +392,7 @@ def load_text_encoders(
|
||||||
|
|
||||||
# session for te2
|
# session for te2
|
||||||
text_encoder_2_session = InferenceSession(
|
text_encoder_2_session = InferenceSession(
|
||||||
text_encoder_2,
|
text_encoder_2.SerializeToString(),
|
||||||
providers=[device.ort_provider("text-encoder")],
|
providers=[device.ort_provider("text-encoder")],
|
||||||
sess_options=text_encoder_2_opts,
|
sess_options=text_encoder_2_opts,
|
||||||
)
|
)
|
||||||
|
@ -402,7 +402,7 @@ def load_text_encoders(
|
||||||
# session for te
|
# session for te
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
text_encoder,
|
text_encoder.SerializeToString(),
|
||||||
provider=device.ort_provider("text-encoder"),
|
provider=device.ort_provider("text-encoder"),
|
||||||
sess_options=text_encoder_opts,
|
sess_options=text_encoder_opts,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue