1
0
Fork 0

fix text encoder loading

This commit is contained in:
Sean Sube 2023-09-24 15:02:39 -05:00
parent 0ecae65f88
commit 85b4245cef
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 5 additions and 4 deletions

View File

@ -328,11 +328,12 @@ def load_text_encoders(
# should be pretty small and should not need external data # should be pretty small and should not need external data
if loras is None or len(loras) == 0: if loras is None or len(loras) == 0:
if params.is_xl():
text_encoder = path.join(model, "text_encoder", ONNX_MODEL) text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
if params.is_xl():
text_encoder_opts = device.sess_options(cache=False) text_encoder_opts = device.sess_options(cache=False)
text_encoder_session = InferenceSession( text_encoder_session = InferenceSession(
text_encoder.SerializeToString(), text_encoder,
providers=[device.ort_provider("text-encoder")], providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_opts, sess_options=text_encoder_opts,
) )
@ -342,7 +343,7 @@ def load_text_encoders(
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL) text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
text_encoder_2_opts = device.sess_options(cache=False) text_encoder_2_opts = device.sess_options(cache=False)
text_encoder_2_session = InferenceSession( text_encoder_2_session = InferenceSession(
text_encoder_2.SerializeToString(), text_encoder_2,
providers=[device.ort_provider("text-encoder")], providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_2_opts, sess_options=text_encoder_2_opts,
) )
@ -350,7 +351,7 @@ def load_text_encoders(
else: else:
components["text_encoder"] = OnnxRuntimeModel( components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model( OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(), text_encoder,
provider=device.ort_provider("text-encoder"), provider=device.ort_provider("text-encoder"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )