move text encoder 2 loading
This commit is contained in:
parent
6b6f63564e
commit
d11b37f0b2
|
@ -369,6 +369,32 @@ def load_text_encoders(
|
||||||
)
|
)
|
||||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||||
components["text_encoder_session"] = text_encoder_session
|
components["text_encoder_session"] = text_encoder_session
|
||||||
|
|
||||||
|
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
||||||
|
text_encoder_2 = blend_loras(
|
||||||
|
server,
|
||||||
|
text_encoder_2,
|
||||||
|
list(zip(lora_models, lora_weights)),
|
||||||
|
"text_encoder",
|
||||||
|
2,
|
||||||
|
params.is_xl(),
|
||||||
|
)
|
||||||
|
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
||||||
|
text_encoder_2
|
||||||
|
)
|
||||||
|
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
|
||||||
|
text_encoder_2_opts = device.sess_options(cache=False)
|
||||||
|
text_encoder_2_opts.add_external_initializers(
|
||||||
|
list(text_encoder_2_names), list(text_encoder_2_values)
|
||||||
|
)
|
||||||
|
|
||||||
|
text_encoder_2_session = InferenceSession(
|
||||||
|
text_encoder_2.SerializeToString(),
|
||||||
|
providers=[device.ort_provider("text-encoder")],
|
||||||
|
sess_options=text_encoder_2_opts,
|
||||||
|
)
|
||||||
|
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||||
|
components["text_encoder_2_session"] = text_encoder_2_session
|
||||||
else:
|
else:
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
|
@ -378,33 +404,6 @@ def load_text_encoders(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.is_xl():
|
|
||||||
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
|
||||||
text_encoder_2 = blend_loras(
|
|
||||||
server,
|
|
||||||
text_encoder_2,
|
|
||||||
list(zip(lora_models, lora_weights)),
|
|
||||||
"text_encoder",
|
|
||||||
2,
|
|
||||||
params.is_xl(),
|
|
||||||
)
|
|
||||||
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
|
||||||
text_encoder_2
|
|
||||||
)
|
|
||||||
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
|
|
||||||
text_encoder_2_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_2_opts.add_external_initializers(
|
|
||||||
list(text_encoder_2_names), list(text_encoder_2_values)
|
|
||||||
)
|
|
||||||
|
|
||||||
text_encoder_2_session = InferenceSession(
|
|
||||||
text_encoder_2.SerializeToString(),
|
|
||||||
providers=[device.ort_provider("text-encoder")],
|
|
||||||
sess_options=text_encoder_2_opts,
|
|
||||||
)
|
|
||||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
|
||||||
components["text_encoder_2_session"] = text_encoder_2_session
|
|
||||||
|
|
||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue