fix text encoder loading
This commit is contained in:
parent
0ecae65f88
commit
85b4245cef
|
@ -328,11 +328,12 @@ def load_text_encoders(
|
|||
|
||||
# should be pretty small and should not need external data
|
||||
if loras is None or len(loras) == 0:
|
||||
text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
|
||||
|
||||
if params.is_xl():
|
||||
text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
|
||||
text_encoder_opts = device.sess_options(cache=False)
|
||||
text_encoder_session = InferenceSession(
|
||||
text_encoder.SerializeToString(),
|
||||
text_encoder,
|
||||
providers=[device.ort_provider("text-encoder")],
|
||||
sess_options=text_encoder_opts,
|
||||
)
|
||||
|
@ -342,7 +343,7 @@ def load_text_encoders(
|
|||
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
||||
text_encoder_2_opts = device.sess_options(cache=False)
|
||||
text_encoder_2_session = InferenceSession(
|
||||
text_encoder_2.SerializeToString(),
|
||||
text_encoder_2,
|
||||
providers=[device.ort_provider("text-encoder")],
|
||||
sess_options=text_encoder_2_opts,
|
||||
)
|
||||
|
@ -350,7 +351,7 @@ def load_text_encoders(
|
|||
else:
|
||||
components["text_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
text_encoder.SerializeToString(),
|
||||
text_encoder,
|
||||
provider=device.ort_provider("text-encoder"),
|
||||
sess_options=device.sess_options(),
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue