pass XL components as ORT sessions
This commit is contained in:
parent
9999a1bf56
commit
2047f7d8cf
|
@ -223,6 +223,7 @@ def load_pipeline(
|
|||
|
||||
# should be pretty small and should not need external data
|
||||
if loras is None or len(loras) == 0:
|
||||
# TODO: handle XL encoders
|
||||
components["text_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
text_encoder.SerializeToString(),
|
||||
|
@ -267,9 +268,7 @@ def load_pipeline(
|
|||
sess_options=text_encoder_opts,
|
||||
)
|
||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||
components["text_encoder"] = ORTModelTextEncoder(
|
||||
text_encoder_session, text_encoder
|
||||
)
|
||||
components["text_encoder_session"] = text_encoder_session
|
||||
else:
|
||||
components["text_encoder"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
|
@ -280,33 +279,31 @@ def load_pipeline(
|
|||
)
|
||||
|
||||
if params.is_xl():
|
||||
text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
||||
text_encoder2 = blend_loras(
|
||||
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
||||
text_encoder_2 = blend_loras(
|
||||
server,
|
||||
text_encoder2,
|
||||
text_encoder_2,
|
||||
list(zip(lora_models, lora_weights)),
|
||||
"text_encoder",
|
||||
2,
|
||||
params.is_xl()
|
||||
)
|
||||
(text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
|
||||
text_encoder2
|
||||
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
||||
text_encoder_2
|
||||
)
|
||||
text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
|
||||
text_encoder2_opts = device.sess_options(cache=False)
|
||||
text_encoder2_opts.add_external_initializers(
|
||||
list(text_encoder2_names), list(text_encoder2_values)
|
||||
text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
|
||||
text_encoder_2_opts = device.sess_options(cache=False)
|
||||
text_encoder_2_opts.add_external_initializers(
|
||||
list(text_encoder_2_names), list(text_encoder_2_values)
|
||||
)
|
||||
|
||||
text_encoder2_session = InferenceSession(
|
||||
text_encoder2.SerializeToString(),
|
||||
text_encoder_2_session = InferenceSession(
|
||||
text_encoder_2.SerializeToString(),
|
||||
providers=[device.ort_provider("text-encoder")],
|
||||
sess_options=text_encoder2_opts,
|
||||
)
|
||||
text_encoder2_session._model_path = path.join(model, "text_encoder_2")
|
||||
components["text_encoder_2"] = ORTModelTextEncoder(
|
||||
text_encoder2_session, text_encoder2
|
||||
sess_options=text_encoder_2_opts,
|
||||
)
|
||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||
components["text_encoder_2_session"] = text_encoder_2_session
|
||||
|
||||
# blend and load unet
|
||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
||||
|
@ -329,7 +326,7 @@ def load_pipeline(
|
|||
sess_options=unet_opts,
|
||||
)
|
||||
unet_session._model_path = path.join(model, "unet")
|
||||
components["unet"] = ORTModelUnet(unet_session, unet_model)
|
||||
components["unet_session"] = unet_session
|
||||
else:
|
||||
components["unet"] = OnnxRuntimeModel(
|
||||
OnnxRuntimeModel.load_model(
|
||||
|
@ -404,27 +401,29 @@ def load_pipeline(
|
|||
|
||||
# make sure XL models are actually being used
|
||||
# TODO: why is this needed?
|
||||
if "text_encoder" in components:
|
||||
if "text_encoder_session" in components:
|
||||
logger.info(
|
||||
"text encoder matches: %s, %s",
|
||||
pipe.text_encoder == components["text_encoder"],
|
||||
pipe.text_encoder.session == components["text_encoder_session"],
|
||||
type(pipe.text_encoder),
|
||||
)
|
||||
pipe.text_encoder = components["text_encoder"]
|
||||
pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder)
|
||||
|
||||
if "text_encoder_2" in components:
|
||||
if "text_encoder_2_session" in components:
|
||||
logger.info(
|
||||
"text encoder 2 matches: %s, %s",
|
||||
pipe.text_encoder_2 == components["text_encoder_2"],
|
||||
pipe.text_encoder_2.session == components["text_encoder_2_session"],
|
||||
type(pipe.text_encoder_2),
|
||||
)
|
||||
pipe.text_encoder_2 = components["text_encoder_2"]
|
||||
pipe.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, text_encoder_2)
|
||||
|
||||
if "unet" in components:
|
||||
if "unet_session" in components:
|
||||
logger.info(
|
||||
"unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet)
|
||||
"unet matches: %s, %s",
|
||||
pipe.unet.session == components["unet_session"],
|
||||
type(pipe.unet),
|
||||
)
|
||||
pipe.unet = components["unet"]
|
||||
pipe.unet = ORTModelUnet(unet_session, unet_model)
|
||||
|
||||
if not server.show_progress:
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
|
Loading…
Reference in New Issue