set model path for VAE sessions
This commit is contained in:
parent
d4b013068d
commit
4ccdedba89
|
@ -376,6 +376,7 @@ def load_pipeline(
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
components["vae_decoder_session"]._model_path = vae_decoder
|
||||||
|
|
||||||
logger.debug("loading VAE encoder from %s", vae_encoder)
|
logger.debug("loading VAE encoder from %s", vae_encoder)
|
||||||
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
|
||||||
|
@ -383,6 +384,7 @@ def load_pipeline(
|
||||||
provider=device.ort_provider("vae"),
|
provider=device.ort_provider("vae"),
|
||||||
sess_options=device.sess_options(),
|
sess_options=device.sess_options(),
|
||||||
)
|
)
|
||||||
|
components["vae_encoder_session"]._model_path = vae_encoder
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.debug("loading VAE decoder from %s", vae_decoder)
|
logger.debug("loading VAE decoder from %s", vae_decoder)
|
||||||
|
|
Loading…
Reference in New Issue