1
0
Fork 0

pass XL components as ORT sessions

This commit is contained in:
Sean Sube 2023-09-03 16:09:01 -05:00
parent 9999a1bf56
commit 2047f7d8cf
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 28 additions and 29 deletions

View File

@ -223,6 +223,7 @@ def load_pipeline(
# should be pretty small and should not need external data
if loras is None or len(loras) == 0:
# TODO: handle XL encoders
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(),
@ -267,9 +268,7 @@ def load_pipeline(
sess_options=text_encoder_opts,
)
text_encoder_session._model_path = path.join(model, "text_encoder")
components["text_encoder"] = ORTModelTextEncoder(
text_encoder_session, text_encoder
)
components["text_encoder_session"] = text_encoder_session
else:
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
@ -280,33 +279,31 @@ def load_pipeline(
)
if params.is_xl():
text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
text_encoder2 = blend_loras(
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
text_encoder_2 = blend_loras(
server,
text_encoder2,
text_encoder_2,
list(zip(lora_models, lora_weights)),
"text_encoder",
2,
params.is_xl()
)
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
text_encoder2
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
text_encoder_2
)
text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
text_encoder2_opts = device.sess_options(cache=False)
text_encoder2_opts.add_external_initializers(
list(text_encoder2_names), list(text_encoder2_values)
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
text_encoder_2_opts = device.sess_options(cache=False)
text_encoder_2_opts.add_external_initializers(
list(text_encoder_2_names), list(text_encoder_2_values)
)
text_encoder2_session = InferenceSession(
text_encoder2.SerializeToString(),
text_encoder_2_session = InferenceSession(
text_encoder_2.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder2_opts,
)
text_encoder2_session._model_path = path.join(model, "text_encoder_2")
components["text_encoder_2"] = ORTModelTextEncoder(
text_encoder2_session, text_encoder2
sess_options=text_encoder_2_opts,
)
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
components["text_encoder_2_session"] = text_encoder_2_session
# blend and load unet
unet = path.join(model, unet_type, ONNX_MODEL)
@ -329,7 +326,7 @@ def load_pipeline(
sess_options=unet_opts,
)
unet_session._model_path = path.join(model, "unet")
components["unet"] = ORTModelUnet(unet_session, unet_model)
components["unet_session"] = unet_session
else:
components["unet"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
@ -404,27 +401,29 @@ def load_pipeline(
# make sure XL models are actually being used
# TODO: why is this needed?
if "text_encoder" in components:
if "text_encoder_session" in components:
logger.info(
"text encoder matches: %s, %s",
pipe.text_encoder == components["text_encoder"],
pipe.text_encoder.session == components["text_encoder_session"],
type(pipe.text_encoder),
)
pipe.text_encoder = components["text_encoder"]
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
if "text_encoder_2" in components:
if "text_encoder_2_session" in components:
logger.info(
"text encoder 2 matches: %s, %s",
pipe.text_encoder_2 == components["text_encoder_2"],
pipe.text_encoder_2.session == components["text_encoder_2_session"],
type(pipe.text_encoder_2),
)
pipe.text_encoder_2 = components["text_encoder_2"]
pipe.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, text_encoder_2)
if "unet" in components:
if "unet_session" in components:
logger.info(
"unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet)
"unet matches: %s, %s",
pipe.unet.session == components["unet_session"],
type(pipe.unet),
)
pipe.unet = components["unet"]
pipe.unet = ORTModelUnet(unet_session, unet_model)
if not server.show_progress:
pipe.set_progress_bar_config(disable=True)