diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py index 4e4c8c90..3e0c2920 100644 --- a/api/onnx_web/diffusers/load.py +++ b/api/onnx_web/diffusers/load.py @@ -369,6 +369,32 @@ def load_text_encoders( ) text_encoder_session._model_path = path.join(model, "text_encoder") components["text_encoder_session"] = text_encoder_session + + text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL) + text_encoder_2 = blend_loras( + server, + text_encoder_2, + list(zip(lora_models, lora_weights)), + "text_encoder", + 2, + params.is_xl(), + ) + (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors( + text_encoder_2 + ) + text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data) + text_encoder_2_opts = device.sess_options(cache=False) + text_encoder_2_opts.add_external_initializers( + list(text_encoder_2_names), list(text_encoder_2_values) + ) + + text_encoder_2_session = InferenceSession( + text_encoder_2.SerializeToString(), + providers=[device.ort_provider("text-encoder")], + sess_options=text_encoder_2_opts, + ) + text_encoder_2_session._model_path = path.join(model, "text_encoder_2") + components["text_encoder_2_session"] = text_encoder_2_session else: components["text_encoder"] = OnnxRuntimeModel( OnnxRuntimeModel.load_model( @@ -378,33 +404,6 @@ def load_text_encoders( ) ) - if params.is_xl(): - text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL) - text_encoder_2 = blend_loras( - server, - text_encoder_2, - list(zip(lora_models, lora_weights)), - "text_encoder", - 2, - params.is_xl(), - ) - (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors( - text_encoder_2 - ) - text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data) - text_encoder_2_opts = device.sess_options(cache=False) - text_encoder_2_opts.add_external_initializers( - list(text_encoder_2_names), list(text_encoder_2_values) - ) - - text_encoder_2_session = InferenceSession( - text_encoder_2.SerializeToString(), - providers=[device.ort_provider("text-encoder")], - sess_options=text_encoder_2_opts, - ) - text_encoder_2_session._model_path = path.join(model, "text_encoder_2") - components["text_encoder_2_session"] = text_encoder_2_session - return components