From 2047f7d8cf81b176c59024b29e2d3a53ed10eee9 Mon Sep 17 00:00:00 2001
From: Sean Sube
Date: Sun, 3 Sep 2023 16:09:01 -0500
Subject: [PATCH] pass XL components as ORT sessions

---
 api/onnx_web/diffusers/load.py | 57 ++++++++++++++++++++++++++++-----------------------------
 1 file changed, 28 insertions(+), 29 deletions(-)

diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index 95f47fd2..77bb0d1c 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -223,6 +223,7 @@ def load_pipeline(

     # should be pretty small and should not need external data
     if loras is None or len(loras) == 0:
+        # TODO: handle XL encoders
         components["text_encoder"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
                 text_encoder.SerializeToString(),
@@ -267,9 +268,7 @@
                 sess_options=text_encoder_opts,
             )
             text_encoder_session._model_path = path.join(model, "text_encoder")
-            components["text_encoder"] = ORTModelTextEncoder(
-                text_encoder_session, text_encoder
-            )
+            components["text_encoder_session"] = text_encoder_session
         else:
             components["text_encoder"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
@@ -280,33 +279,31 @@


         if params.is_xl():
-            text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
-            text_encoder2 = blend_loras(
+            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
+            text_encoder_2 = blend_loras(
                 server,
-                text_encoder2,
+                text_encoder_2,
                 list(zip(lora_models, lora_weights)),
                 "text_encoder",
                 2,
                 params.is_xl()
             )
-            (text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
-                text_encoder2
+            (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
+                text_encoder_2
             )
-            text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
-            text_encoder2_opts = device.sess_options(cache=False)
-            text_encoder2_opts.add_external_initializers(
-                list(text_encoder2_names), list(text_encoder2_values)
+            text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
+            text_encoder_2_opts = device.sess_options(cache=False)
+            text_encoder_2_opts.add_external_initializers(
+                list(text_encoder_2_names), list(text_encoder_2_values)
             )

-            text_encoder2_session = InferenceSession(
-                text_encoder2.SerializeToString(),
+            text_encoder_2_session = InferenceSession(
+                text_encoder_2.SerializeToString(),
                 providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder2_opts,
-            )
-            text_encoder2_session._model_path = path.join(model, "text_encoder_2")
-            components["text_encoder_2"] = ORTModelTextEncoder(
-                text_encoder2_session, text_encoder2
+                sess_options=text_encoder_2_opts,
             )
+            text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
+            components["text_encoder_2_session"] = text_encoder_2_session

         # blend and load unet
         unet = path.join(model, unet_type, ONNX_MODEL)
@@ -329,7 +326,7 @@
                 sess_options=unet_opts,
             )
             unet_session._model_path = path.join(model, "unet")
-            components["unet"] = ORTModelUnet(unet_session, unet_model)
+            components["unet_session"] = unet_session
         else:
             components["unet"] = OnnxRuntimeModel(
                 OnnxRuntimeModel.load_model(
@@ -404,27 +401,29 @@

     # make sure XL models are actually being used
     # TODO: why is this needed?
- if "text_encoder" in components: + if "text_encoder_session" in components: logger.info( "text encoder matches: %s, %s", - pipe.text_encoder == components["text_encoder"], + pipe.text_encoder.session == components["text_encoder_session"], type(pipe.text_encoder), ) - pipe.text_encoder = components["text_encoder"] + pipe.text_encoder = ORTModelTextEncoder(text_encoder_session, text_encoder) - if "text_encoder_2" in components: + if "text_encoder_2_session" in components: logger.info( "text encoder 2 matches: %s, %s", - pipe.text_encoder_2 == components["text_encoder_2"], + pipe.text_encoder_2.session == components["text_encoder_2_session"], type(pipe.text_encoder_2), ) - pipe.text_encoder_2 = components["text_encoder_2"] + pipe.text_encoder_2 = ORTModelTextEncoder(text_encoder_2_session, text_encoder_2) - if "unet" in components: + if "unet_session" in components: logger.info( - "unet matches: %s, %s", pipe.unet == components["unet"], type(pipe.unet) + "unet matches: %s, %s", + pipe.unet.session == components["unet_session"], + type(pipe.unet), ) - pipe.unet = components["unet"] + pipe.unet = ORTModelUnet(unet_session, unet_model) if not server.show_progress: pipe.set_progress_bar_config(disable=True)