1
0
Fork 0

serialize models before loading

This commit is contained in:
Sean Sube 2023-09-24 18:04:23 -05:00
parent a3a04fd1f4
commit e1c2ae5b1b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 3 additions and 3 deletions

View File

@ -383,7 +383,7 @@ def load_text_encoders(
# session for te1 # session for te1
text_encoder_session = InferenceSession( text_encoder_session = InferenceSession(
text_encoder, text_encoder.SerializeToString(),
providers=[device.ort_provider("text-encoder")], providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_opts, sess_options=text_encoder_opts,
) )
@ -392,7 +392,7 @@ def load_text_encoders(
# session for te2 # session for te2
text_encoder_2_session = InferenceSession( text_encoder_2_session = InferenceSession(
text_encoder_2, text_encoder_2.SerializeToString(),
providers=[device.ort_provider("text-encoder")], providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_2_opts, sess_options=text_encoder_2_opts,
) )
@ -402,7 +402,7 @@ def load_text_encoders(
# session for te # session for te
components["text_encoder"] = OnnxRuntimeModel( components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model( OnnxRuntimeModel.load_model(
text_encoder, text_encoder.SerializeToString(),
provider=device.ort_provider("text-encoder"), provider=device.ort_provider("text-encoder"),
sess_options=text_encoder_opts, sess_options=text_encoder_opts,
) )