1
0
Fork 0

fix XL text encoder loading

This commit is contained in:
Sean Sube 2023-09-24 10:04:44 -05:00
parent 5d3a7d77a5
commit 539140909b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 23 additions and 19 deletions

View File

@ -195,9 +195,7 @@ def load_pipeline(
) )
components.update(encoder_components) components.update(encoder_components)
unet_components = load_unet( unet_components = load_unet(server, device, model, loras, unet_type, params)
server, device, model, loras, unet_type, params
)
components.update(unet_components) components.update(unet_components)
vae_components = load_vae(server, device, model, params) vae_components = load_vae(server, device, model, params)
@ -330,23 +328,33 @@ def load_text_encoders(
# should be pretty small and should not need external data # should be pretty small and should not need external data
if loras is None or len(loras) == 0: if loras is None or len(loras) == 0:
# TODO: handle XL encoders
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(),
provider=device.ort_provider("text-encoder"),
sess_options=device.sess_options(),
)
)
if params.is_xl(): if params.is_xl():
text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
text_encoder_opts = device.sess_options(cache=False)
text_encoder_session = InferenceSession(
text_encoder.SerializeToString(),
providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_opts,
)
text_encoder_session._model_path = path.join(model, "text_encoder")
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
text_encoder_2_opts = device.sess_options(cache=False)
text_encoder_2_session = InferenceSession( text_encoder_2_session = InferenceSession(
text_encoder_2.SerializeToString(), text_encoder_2.SerializeToString(),
providers=[device.ort_provider("text-encoder")], providers=[device.ort_provider("text-encoder")],
sess_options=text_encoder_2_opts, sess_options=text_encoder_2_opts,
) )
text_encoder_2_session._model_path = path.join(model, "text_encoder_2") text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
components["text_encoder_2_session"] = text_encoder_2_session else:
components["text_encoder"] = OnnxRuntimeModel(
OnnxRuntimeModel.load_model(
text_encoder.SerializeToString(),
provider=device.ort_provider("text-encoder"),
sess_options=device.sess_options(),
)
)
else: else:
# blend and load text encoder # blend and load text encoder
@ -498,9 +506,7 @@ def load_vae(server, device, model, params):
provider=device.ort_provider("vae"), provider=device.ort_provider("vae"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )
components[ components["vae_decoder_session"]._model_path = vae_decoder
"vae_decoder_session"
]._model_path = vae_decoder
logger.debug("loading VAE encoder from %s", vae_encoder) logger.debug("loading VAE encoder from %s", vae_encoder)
components["vae_encoder_session"] = OnnxRuntimeModel.load_model( components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
@ -508,9 +514,7 @@ def load_vae(server, device, model, params):
provider=device.ort_provider("vae"), provider=device.ort_provider("vae"),
sess_options=device.sess_options(), sess_options=device.sess_options(),
) )
components[ components["vae_encoder_session"]._model_path = vae_encoder
"vae_encoder_session"
]._model_path = vae_encoder
else: else:
logger.debug("loading VAE decoder from %s", vae_decoder) logger.debug("loading VAE decoder from %s", vae_decoder)