diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 2773922c..60ef7090 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -195,9 +195,7 @@ def load_pipeline(
         )
         components.update(encoder_components)
 
-        unet_components = load_unet(
-            server, device, model, loras, unet_type, params
-        )
+        unet_components = load_unet(server, device, model, loras, unet_type, params)
         components.update(unet_components)
 
         vae_components = load_vae(server, device, model, params)
@@ -330,23 +328,33 @@ def load_text_encoders(
     # should be pretty small and should not need external data
     if loras is None or len(loras) == 0:
-        # TODO: handle XL encoders
-        components["text_encoder"] = OnnxRuntimeModel(
-            OnnxRuntimeModel.load_model(
-                text_encoder.SerializeToString(),
-                provider=device.ort_provider("text-encoder"),
-                sess_options=device.sess_options(),
-            )
-        )
-
         if params.is_xl():
+            text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
+            text_encoder_opts = device.sess_options(cache=False)
+            text_encoder_session = InferenceSession(
+                text_encoder,
+                providers=[device.ort_provider("text-encoder")],
+                sess_options=text_encoder_opts,
+            )
+
+            text_encoder_session._model_path = path.join(model, "text_encoder")
+
+            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
+            text_encoder_2_opts = device.sess_options(cache=False)
             text_encoder_2_session = InferenceSession(
-                text_encoder_2.SerializeToString(),
+                text_encoder_2,
                 providers=[device.ort_provider("text-encoder")],
                 sess_options=text_encoder_2_opts,
             )
 
             text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
-            components["text_encoder_2_session"] = text_encoder_2_session
+        else:
+            components["text_encoder"] = OnnxRuntimeModel(
+                OnnxRuntimeModel.load_model(
+                    text_encoder.SerializeToString(),
+                    provider=device.ort_provider("text-encoder"),
+                    sess_options=device.sess_options(),
+                )
+            )
 
     else:
         # blend and load text encoder
@@ -498,9 +506,7 @@ def load_vae(server, device, model, params):
             provider=device.ort_provider("vae"),
             sess_options=device.sess_options(),
         )
-        components[
-            "vae_decoder_session"
-        ]._model_path = vae_decoder
+        components["vae_decoder_session"]._model_path = vae_decoder
 
         logger.debug("loading VAE encoder from %s", vae_encoder)
         components["vae_encoder_session"] = OnnxRuntimeModel.load_model(
@@ -508,9 +514,7 @@ def load_vae(server, device, model, params):
             provider=device.ort_provider("vae"),
            sess_options=device.sess_options(),
         )
-        components[
-            "vae_encoder_session"
-        ]._model_path = vae_encoder
+        components["vae_encoder_session"]._model_path = vae_encoder
 
     else:
         logger.debug("loading VAE decoder from %s", vae_decoder)
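
For context, outside the patch itself: in the no-LoRA XL branch, the second hunk stops wrapping the text encoders in OnnxRuntimeModel and instead builds raw onnxruntime.InferenceSession objects from the model paths, tagging each session with its source directory via _model_path. A minimal standalone sketch of that pattern follows; the helper name load_encoder_session is hypothetical, ONNX_MODEL mirrors the onnx-web constant, and a plain SessionOptions with an explicit provider stands in for the device.sess_options(cache=False) / device.ort_provider("text-encoder") plumbing from the repo.

    # Sketch only, not part of the patch. Assumes a converted model directory
    # laid out like onnx-web's, e.g. <model>/text_encoder/model.onnx.
    from os import path

    from onnxruntime import InferenceSession, SessionOptions

    ONNX_MODEL = "model.onnx"  # assumed value of the onnx-web constant


    def load_encoder_session(model_dir, encoder, provider="CPUExecutionProvider"):
        # Fresh per-session options; stands in for the patch's
        # device.sess_options(cache=False) helper.
        opts = SessionOptions()

        # InferenceSession accepts a filename directly, so the model does not
        # need to be loaded and serialized in memory first.
        session = InferenceSession(
            path.join(model_dir, encoder, ONNX_MODEL),
            providers=[provider],
            sess_options=opts,
        )

        # Tag the session with its source directory, as the patch does, so
        # later code can locate the original model files.
        session._model_path = path.join(model_dir, encoder)
        return session


    # Mirroring the XL branch of the hunk above:
    # text_encoder_session = load_encoder_session(model, "text_encoder")
    # text_encoder_2_session = load_encoder_session(model, "text_encoder_2")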