diff --git a/api/onnx_web/diffusers/load.py b/api/onnx_web/diffusers/load.py
index bc303579..a904aa36 100644
--- a/api/onnx_web/diffusers/load.py
+++ b/api/onnx_web/diffusers/load.py
@@ -294,7 +294,6 @@ def load_controlnet(server, device, params):
 def load_text_encoders(
     server, device, model: str, inversions, loras, torch_dtype, params
 ):
-    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
     tokenizer = CLIPTokenizer.from_pretrained(
         model,
         subfolder="tokenizer",
@@ -304,14 +303,23 @@ def load_text_encoders(
 
     components = {}
     components["tokenizer"] = tokenizer
 
-    if inversions is not None and len(inversions) > 0:
-        logger.debug("blending Textual Inversions from %s", inversions)
-        inversion_names, inversion_weights = zip(*inversions)
+    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
+    text_encoder_2 = None
+    if params.is_xl():
+        text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
+
+    # blend embeddings, if any
+    if inversions is not None and len(inversions) > 0:
+        inversion_names, inversion_weights = zip(*inversions)
         inversion_models = [
             path.join(server.model_path, "inversion", name)
             for name in inversion_names
         ]
+        logger.debug(
+            "blending base model %s with embeddings from %s", model, inversion_models
+        )
+        # TODO: blend text_encoder_2 as well
         text_encoder, tokenizer = blend_textual_inversions(
             server,
             text_encoder,
@@ -326,45 +334,15 @@ def load_text_encoders(
             ),
         )
 
-    # should be pretty small and should not need external data
-    if loras is None or len(loras) == 0:
-        text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
-
-        if params.is_xl():
-            text_encoder_opts = device.sess_options(cache=False)
-            text_encoder_session = InferenceSession(
-                text_encoder,
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_opts,
-            )
-
-            text_encoder_session._model_path = path.join(model, "text_encoder")
-
-            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
-            text_encoder_2_opts = device.sess_options(cache=False)
-            text_encoder_2_session = InferenceSession(
-                text_encoder_2,
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_2_opts,
-            )
-            text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
-        else:
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder,
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=device.sess_options(),
-                )
-            )
-
-    else:
-        # blend and load text encoder
+    # blend LoRAs, if any
+    if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
         lora_models = [
             path.join(server.model_path, "lora", name) for name in lora_names
         ]
         logger.info("blending base model %s with LoRA models: %s", model, lora_models)
 
+        # blend and load text encoder
         text_encoder = blend_loras(
             server,
             text_encoder,
@@ -373,23 +351,8 @@ def load_text_encoders(
             1 if params.is_xl() else None,
             params.is_xl(),
         )
-        (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
-        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
-        text_encoder_opts = device.sess_options(cache=False)
-        text_encoder_opts.add_external_initializers(
-            list(text_encoder_names), list(text_encoder_values)
-        )
 
         if params.is_xl():
-            text_encoder_session = InferenceSession(
-                text_encoder.SerializeToString(),
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_opts,
-            )
-            text_encoder_session._model_path = path.join(model, "text_encoder")
-            components["text_encoder_session"] = text_encoder_session
-
-            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
             text_encoder_2 = blend_loras(
                 server,
                 text_encoder_2,
@@ -398,37 +361,59 @@ def load_text_encoders(
                 2,
                 params.is_xl(),
             )
 
-            (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
-                text_encoder_2
-            )
-            text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
-            text_encoder_2_opts = device.sess_options(cache=False)
-            text_encoder_2_opts.add_external_initializers(
-                list(text_encoder_2_names), list(text_encoder_2_values)
-            )
-            text_encoder_2_session = InferenceSession(
-                text_encoder_2.SerializeToString(),
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_2_opts,
-            )
-            text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
-            components["text_encoder_2_session"] = text_encoder_2_session
-        else:
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder.SerializeToString(),
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=text_encoder_opts,
-                )
+    # prepare external data for sessions
+    (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
+    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
+    text_encoder_opts = device.sess_options(cache=False)
+    text_encoder_opts.add_external_initializers(
+        list(text_encoder_names), list(text_encoder_values)
+    )
+
+    if params.is_xl():
+        # encoder 2 only exists in XL
+        (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
+            text_encoder_2
+        )
+        text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
+        text_encoder_2_opts = device.sess_options(cache=False)
+        text_encoder_2_opts.add_external_initializers(
+            list(text_encoder_2_names), list(text_encoder_2_values)
+        )
+
+        # session for te1
+        text_encoder_session = InferenceSession(
+            text_encoder,
+            providers=[device.ort_provider("text-encoder")],
+            sess_options=text_encoder_opts,
+        )
+        text_encoder_session._model_path = path.join(model, "text_encoder")
+        components["text_encoder_session"] = text_encoder_session
+
+        # session for te2
+        text_encoder_2_session = InferenceSession(
+            text_encoder_2,
+            providers=[device.ort_provider("text-encoder")],
+            sess_options=text_encoder_2_opts,
+        )
+        text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
+        components["text_encoder_2_session"] = text_encoder_2_session
+    else:
+        # session for te
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder,
+                provider=device.ort_provider("text-encoder"),
+                sess_options=text_encoder_opts,
             )
+        )
 
     return components
 
 
 def load_unet(server, device, model, loras, unet_type, params):
     components = {}
-    unet = path.join(model, unet_type, ONNX_MODEL)
+    unet = load_model(path.join(model, unet_type, ONNX_MODEL))
 
     # LoRA blending
     if loras is not None and len(loras) > 0:
@@ -446,37 +431,26 @@ def load_unet(server, device, model, loras, unet_type, params):
             "unet",
             xl=params.is_xl(),
         )
 
-        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
-        unet_names, unet_values = zip(*unet_data)
-        unet_opts = device.sess_options(cache=False)
-        unet_opts.add_external_initializers(list(unet_names), list(unet_values))
-        if params.is_xl():
-            unet_session = InferenceSession(
-                unet_model.SerializeToString(),
-                providers=[device.ort_provider("unet")],
-                sess_options=unet_opts,
-            )
-            unet_session._model_path = path.join(model, "unet")
-            components["unet_session"] = unet_session
-        else:
-            components["unet"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    unet_model.SerializeToString(),
-                    provider=device.ort_provider("unet"),
-                    sess_options=unet_opts,
-                )
-            )
+    (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
+    unet_names, unet_values = zip(*unet_data)
+    unet_opts = device.sess_options(cache=False)
+    unet_opts.add_external_initializers(list(unet_names), list(unet_values))
 
-    # make sure a UNet has been loaded
-    if not params.is_xl() and "unet" not in components:
-        unet = path.join(model, unet_type, ONNX_MODEL)
-        logger.debug("loading UNet (%s) from %s", unet_type, unet)
+    if params.is_xl():
+        unet_session = InferenceSession(
+            unet_model.SerializeToString(),
+            providers=[device.ort_provider("unet")],
+            sess_options=unet_opts,
+        )
+        unet_session._model_path = path.join(model, "unet")
+        components["unet_session"] = unet_session
+    else:
         components["unet"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
-                unet,
+                unet_model.SerializeToString(),
                 provider=device.ort_provider("unet"),
-                sess_options=device.sess_options(),
+                sess_options=unet_opts,
             )
         )