clean up text encoder loading logic, deduplicate sessions
This commit is contained in:
parent
85b4245cef
commit
a3a04fd1f4
|
@ -294,7 +294,6 @@ def load_controlnet(server, device, params):
|
||||||
def load_text_encoders(
|
def load_text_encoders(
|
||||||
server, device, model: str, inversions, loras, torch_dtype, params
|
server, device, model: str, inversions, loras, torch_dtype, params
|
||||||
):
|
):
|
||||||
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
model,
|
model,
|
||||||
subfolder="tokenizer",
|
subfolder="tokenizer",
|
||||||
|
@ -304,14 +303,23 @@ def load_text_encoders(
|
||||||
components = {}
|
components = {}
|
||||||
components["tokenizer"] = tokenizer
|
components["tokenizer"] = tokenizer
|
||||||
|
|
||||||
if inversions is not None and len(inversions) > 0:
|
text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
|
||||||
logger.debug("blending Textual Inversions from %s", inversions)
|
text_encoder_2 = None
|
||||||
inversion_names, inversion_weights = zip(*inversions)
|
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
|
||||||
|
|
||||||
|
# blend embeddings, if any
|
||||||
|
if inversions is not None and len(inversions) > 0:
|
||||||
|
inversion_names, inversion_weights = zip(*inversions)
|
||||||
inversion_models = [
|
inversion_models = [
|
||||||
path.join(server.model_path, "inversion", name) for name in inversion_names
|
path.join(server.model_path, "inversion", name) for name in inversion_names
|
||||||
]
|
]
|
||||||
|
logger.debug(
|
||||||
|
"blending base model %s with embeddings from %s", model, inversion_models
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: blend text_encoder_2 as well
|
||||||
text_encoder, tokenizer = blend_textual_inversions(
|
text_encoder, tokenizer = blend_textual_inversions(
|
||||||
server,
|
server,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
|
@ -326,45 +334,15 @@ def load_text_encoders(
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# should be pretty small and should not need external data
|
# blend LoRAs, if any
|
||||||
if loras is None or len(loras) == 0:
|
if loras is not None and len(loras) > 0:
|
||||||
text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
|
|
||||||
|
|
||||||
if params.is_xl():
|
|
||||||
text_encoder_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_session = InferenceSession(
|
|
||||||
text_encoder,
|
|
||||||
providers=[device.ort_provider("text-encoder")],
|
|
||||||
sess_options=text_encoder_opts,
|
|
||||||
)
|
|
||||||
|
|
||||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
|
||||||
|
|
||||||
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
|
||||||
text_encoder_2_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_2_session = InferenceSession(
|
|
||||||
text_encoder_2,
|
|
||||||
providers=[device.ort_provider("text-encoder")],
|
|
||||||
sess_options=text_encoder_2_opts,
|
|
||||||
)
|
|
||||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
|
||||||
else:
|
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
text_encoder,
|
|
||||||
provider=device.ort_provider("text-encoder"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# blend and load text encoder
|
|
||||||
lora_names, lora_weights = zip(*loras)
|
lora_names, lora_weights = zip(*loras)
|
||||||
lora_models = [
|
lora_models = [
|
||||||
path.join(server.model_path, "lora", name) for name in lora_names
|
path.join(server.model_path, "lora", name) for name in lora_names
|
||||||
]
|
]
|
||||||
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
logger.info("blending base model %s with LoRA models: %s", model, lora_models)
|
||||||
|
|
||||||
|
# blend and load text encoder
|
||||||
text_encoder = blend_loras(
|
text_encoder = blend_loras(
|
||||||
server,
|
server,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
|
@ -373,23 +351,8 @@ def load_text_encoders(
|
||||||
1 if params.is_xl() else None,
|
1 if params.is_xl() else None,
|
||||||
params.is_xl(),
|
params.is_xl(),
|
||||||
)
|
)
|
||||||
(text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
|
|
||||||
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
|
||||||
text_encoder_opts = device.sess_options(cache=False)
|
|
||||||
text_encoder_opts.add_external_initializers(
|
|
||||||
list(text_encoder_names), list(text_encoder_values)
|
|
||||||
)
|
|
||||||
|
|
||||||
if params.is_xl():
|
if params.is_xl():
|
||||||
text_encoder_session = InferenceSession(
|
|
||||||
text_encoder.SerializeToString(),
|
|
||||||
providers=[device.ort_provider("text-encoder")],
|
|
||||||
sess_options=text_encoder_opts,
|
|
||||||
)
|
|
||||||
text_encoder_session._model_path = path.join(model, "text_encoder")
|
|
||||||
components["text_encoder_session"] = text_encoder_session
|
|
||||||
|
|
||||||
text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
|
|
||||||
text_encoder_2 = blend_loras(
|
text_encoder_2 = blend_loras(
|
||||||
server,
|
server,
|
||||||
text_encoder_2,
|
text_encoder_2,
|
||||||
|
@ -398,6 +361,17 @@ def load_text_encoders(
|
||||||
2,
|
2,
|
||||||
params.is_xl(),
|
params.is_xl(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# prepare external data for sessions
|
||||||
|
(text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
|
||||||
|
text_encoder_names, text_encoder_values = zip(*text_encoder_data)
|
||||||
|
text_encoder_opts = device.sess_options(cache=False)
|
||||||
|
text_encoder_opts.add_external_initializers(
|
||||||
|
list(text_encoder_names), list(text_encoder_values)
|
||||||
|
)
|
||||||
|
|
||||||
|
if params.is_xl():
|
||||||
|
# encoder 2 only exists in XL
|
||||||
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
(text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
|
||||||
text_encoder_2
|
text_encoder_2
|
||||||
)
|
)
|
||||||
|
@ -407,17 +381,28 @@ def load_text_encoders(
|
||||||
list(text_encoder_2_names), list(text_encoder_2_values)
|
list(text_encoder_2_names), list(text_encoder_2_values)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# session for te1
|
||||||
|
text_encoder_session = InferenceSession(
|
||||||
|
text_encoder,
|
||||||
|
providers=[device.ort_provider("text-encoder")],
|
||||||
|
sess_options=text_encoder_opts,
|
||||||
|
)
|
||||||
|
text_encoder_session._model_path = path.join(model, "text_encoder")
|
||||||
|
components["text_encoder_session"] = text_encoder_session
|
||||||
|
|
||||||
|
# session for te2
|
||||||
text_encoder_2_session = InferenceSession(
|
text_encoder_2_session = InferenceSession(
|
||||||
text_encoder_2.SerializeToString(),
|
text_encoder_2,
|
||||||
providers=[device.ort_provider("text-encoder")],
|
providers=[device.ort_provider("text-encoder")],
|
||||||
sess_options=text_encoder_2_opts,
|
sess_options=text_encoder_2_opts,
|
||||||
)
|
)
|
||||||
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
|
||||||
components["text_encoder_2_session"] = text_encoder_2_session
|
components["text_encoder_2_session"] = text_encoder_2_session
|
||||||
else:
|
else:
|
||||||
|
# session for te
|
||||||
components["text_encoder"] = OnnxRuntimeModel(
|
components["text_encoder"] = OnnxRuntimeModel(
|
||||||
OnnxRuntimeModel.load_model(
|
OnnxRuntimeModel.load_model(
|
||||||
text_encoder.SerializeToString(),
|
text_encoder,
|
||||||
provider=device.ort_provider("text-encoder"),
|
provider=device.ort_provider("text-encoder"),
|
||||||
sess_options=text_encoder_opts,
|
sess_options=text_encoder_opts,
|
||||||
)
|
)
|
||||||
|
@ -428,7 +413,7 @@ def load_text_encoders(
|
||||||
|
|
||||||
def load_unet(server, device, model, loras, unet_type, params):
|
def load_unet(server, device, model, loras, unet_type, params):
|
||||||
components = {}
|
components = {}
|
||||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
unet = load_model(path.join(model, unet_type, ONNX_MODEL))
|
||||||
|
|
||||||
# LoRA blending
|
# LoRA blending
|
||||||
if loras is not None and len(loras) > 0:
|
if loras is not None and len(loras) > 0:
|
||||||
|
@ -446,6 +431,7 @@ def load_unet(server, device, model, loras, unet_type, params):
|
||||||
"unet",
|
"unet",
|
||||||
xl=params.is_xl(),
|
xl=params.is_xl(),
|
||||||
)
|
)
|
||||||
|
|
||||||
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
(unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
|
||||||
unet_names, unet_values = zip(*unet_data)
|
unet_names, unet_values = zip(*unet_data)
|
||||||
unet_opts = device.sess_options(cache=False)
|
unet_opts = device.sess_options(cache=False)
|
||||||
|
@ -468,18 +454,6 @@ def load_unet(server, device, model, loras, unet_type, params):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# make sure a UNet has been loaded
|
|
||||||
if not params.is_xl() and "unet" not in components:
|
|
||||||
unet = path.join(model, unet_type, ONNX_MODEL)
|
|
||||||
logger.debug("loading UNet (%s) from %s", unet_type, unet)
|
|
||||||
components["unet"] = OnnxRuntimeModel(
|
|
||||||
OnnxRuntimeModel.load_model(
|
|
||||||
unet,
|
|
||||||
provider=device.ort_provider("unet"),
|
|
||||||
sess_options=device.sess_options(),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return components
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue