fix XL text encoder loading
This commit is contained in:
parent
5d3a7d77a5
commit
539140909b
|
@ -195,9 +195,7 @@ def load_pipeline(
|
||||||
)
|
)
|
||||||
components.update(encoder_components)
|
components.update(encoder_components)
|
||||||
|
|
||||||
unet_components = load_unet(
|
unet_components = load_unet(server, device, model, loras, unet_type, params)
|
||||||
server, device, model, loras, unet_type, params
|
|
||||||
)
|
|
||||||
components.update(unet_components)
|
components.update(unet_components)
|
||||||
|
|
||||||
vae_components = load_vae(server, device, model, params)
|
vae_components = load_vae(server, device, model, params)
|
||||||
|
@ -330,23 +328,33 @@ def load_text_encoders(
|
||||||
|
|
||||||
# should be pretty small and should not need external data
|
# should be pretty small and should not need external data
|
||||||
if loras is None or len(loras) == 0:
|
if loras is None or len(loras) == 0:
|
||||||
# TODO: handle XL encoders
|
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
text_encoder.SerializeToString(),
|
|
||||||
provider=device.ort_provider("text-encoder"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.is_xl():
|
if params.is_xl():
|
||||||
|
text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
|
||||||
|
text_encoder_opts = device.sess_options(cache=False)
|
||||||
|
text_encoder_session = InferenceSession(
|
||||||
|
text_encoder.SerializeToString(),
|
||||||
|
providers=[device.ort_provider("text-encoder")],
|
||||||
|
sess_options=text_encoder_opts,
|
||||||
|
)
|
||||||
|
|
||||||
|
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||||
|
|
||||||
|
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
||||||
|
text_encoder_2_opts = device.sess_options(cache=False)
|
||||||
text_encoder_2_session = InferenceSession(
|
text_encoder_2_session = InferenceSession(
|
||||||
text_encoder_2.SerializeToString(),
|
text_encoder_2.SerializeToString(),
|
||||||
providers=[device.ort_provider("text-encoder")],
|
providers=[device.ort_provider("text-encoder")],
|
||||||
sess_options=text_encoder_2_opts,
|
sess_options=text_encoder_2_opts,
|
||||||
)
|
)
|
||||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||||
components["text_encoder_2_session"] = text_encoder_2_session
|
else:
|
||||||
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
|
OnnxRuntimeModel.load_model(
|
||||||
|
text_encoder.SerializeToString(),
|
||||||
|
provider=device.ort_provider("text-encoder"),
|
||||||
|
sess_options=device.sess_options(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# blend and load text encoder
|
# blend and load text encoder
|
||||||
|
@ -498,9 +506,7 @@ def load_vae(server, device, model, params):
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
components[
|
components["vae_decoder_session"]._model_path = vae_decoder
|
||||||
"vae_decoder_session"
|
|
||||||
]._model_path = vae_decoder
|
|
||||||
|
|
||||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
|
@ -508,9 +514,7 @@ def load_vae(server, device, model, params):
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
components[
|
components["vae_encoder_session"]._model_path = vae_encoder
|
||||||
"vae_encoder_session"
|
|
||||||
]._model_path = vae_encoder
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
|
|
Loading…
Reference in New Issue