
clean up text encoder loading logic, deduplicate sessions

Sean Sube 2023-09-24 18:01:42 -05:00
parent 85b4245cef
commit a3a04fd1f4
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 75 additions and 101 deletions


@@ -294,7 +294,6 @@ def load_controlnet(server, device, params):
 def load_text_encoders(
     server, device, model: str, inversions, loras, torch_dtype, params
 ):
-    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
     tokenizer = CLIPTokenizer.from_pretrained(
         model,
         subfolder="tokenizer",
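Note on the change above: load_model returns an in-memory ONNX ModelProto rather than a file path, which is what lets inversions and LoRAs be blended before any session is created. A minimal sketch of the assumed helper, if it simply wraps onnx.load (the real implementation lives elsewhere in onnx-web and may differ):

    import onnx

    def load_model(model_path: str) -> onnx.ModelProto:
        # hypothetical stand-in for onnx-web's load_model helper: read the
        # protobuf, including any external weight files, into memory
        return onnx.load(model_path)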
@@ -304,14 +303,23 @@ def load_text_encoders(
     components = {}
     components["tokenizer"] = tokenizer

-    if inversions is not None and len(inversions) > 0:
-        logger.debug("blending Textual Inversions from %s", inversions)
-        inversion_names, inversion_weights = zip(*inversions)
+    text_encoder = load_model(path.join(model, "text_encoder", ONNX_MODEL))
+    text_encoder_2 = None
+
+    if params.is_xl():
+        text_encoder_2 = load_model(path.join(model, "text_encoder_2", ONNX_MODEL))
+
+    # blend embeddings, if any
+    if inversions is not None and len(inversions) > 0:
+        inversion_names, inversion_weights = zip(*inversions)
         inversion_models = [
             path.join(server.model_path, "inversion", name) for name in inversion_names
         ]
+        logger.debug(
+            "blending base model %s with embeddings from %s", model, inversion_models
+        )
+
+        # TODO: blend text_encoder_2 as well
         text_encoder, tokenizer = blend_textual_inversions(
             server,
             text_encoder,
@@ -326,45 +334,15 @@ def load_text_encoders(
             ),
         )

-    # should be pretty small and should not need external data
-    if loras is None or len(loras) == 0:
-        text_encoder = path.join(model, "text_encoder", ONNX_MODEL)
-
-        if params.is_xl():
-            text_encoder_opts = device.sess_options(cache=False)
-            text_encoder_session = InferenceSession(
-                text_encoder,
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_opts,
-            )
-            text_encoder_session._model_path = path.join(model, "text_encoder")
-
-            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
-            text_encoder_2_opts = device.sess_options(cache=False)
-            text_encoder_2_session = InferenceSession(
-                text_encoder_2,
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_2_opts,
-            )
-            text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
-        else:
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder,
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=device.sess_options(),
-                )
-            )
-    else:
-        # blend and load text encoder
+    # blend LoRAs, if any
+    if loras is not None and len(loras) > 0:
         lora_names, lora_weights = zip(*loras)
         lora_models = [
             path.join(server.model_path, "lora", name) for name in lora_names
         ]
         logger.info("blending base model %s with LoRA models: %s", model, lora_models)
+
+        # blend and load text encoder
         text_encoder = blend_loras(
             server,
             text_encoder,
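The blend_loras call above merges LoRA weights into the base model's initializers while the model is still an in-memory graph. As a generic illustration of what such a merge computes (not onnx-web's actual blend_loras), each LoRA pair contributes a low-rank update to a base weight matrix:

    import numpy as np

    def merge_lora_weight(
        base: np.ndarray, up: np.ndarray, down: np.ndarray, weight: float
    ) -> np.ndarray:
        # classic LoRA merge: W' = W + weight * (up @ down); the merged W'
        # then replaces the original initializer in the ONNX graph
        return base + weight * (up @ down)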
@@ -373,23 +351,8 @@ def load_text_encoders(
             1 if params.is_xl() else None,
             params.is_xl(),
         )
-        (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
-        text_encoder_names, text_encoder_values = zip(*text_encoder_data)
-        text_encoder_opts = device.sess_options(cache=False)
-        text_encoder_opts.add_external_initializers(
-            list(text_encoder_names), list(text_encoder_values)
-        )

         if params.is_xl():
-            text_encoder_session = InferenceSession(
-                text_encoder.SerializeToString(),
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_opts,
-            )
-            text_encoder_session._model_path = path.join(model, "text_encoder")
-            components["text_encoder_session"] = text_encoder_session
-
-            text_encoder_2 = path.join(model, "text_encoder_2", ONNX_MODEL)
             text_encoder_2 = blend_loras(
                 server,
                 text_encoder_2,
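The buffering removed in this hunk reappears exactly once, dedented, in the next one. The pattern itself: strip large tensors out of the protobuf so SerializeToString() stays under the 2 GB protobuf limit, then hand the tensors to onnxruntime as external initializers. A minimal sketch, assuming buffer_external_data_tensors returns (stripped_model, [(name, OrtValue), ...]) as the surrounding code implies:

    from onnxruntime import InferenceSession, SessionOptions

    def session_from_proto(model_proto, provider: str) -> InferenceSession:
        # hypothetical wrapper around the pattern used throughout this diff
        stripped, external_data = buffer_external_data_tensors(model_proto)  # project helper, assumed in scope
        names, values = zip(*external_data)

        opts = SessionOptions()
        # onnxruntime resolves these initializers from memory instead of disk
        opts.add_external_initializers(list(names), list(values))

        return InferenceSession(
            stripped.SerializeToString(),
            providers=[provider],
            sess_options=opts,
        )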
@@ -398,37 +361,59 @@ def load_text_encoders(
                 2,
                 params.is_xl(),
             )
-            (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
-                text_encoder_2
-            )
-            text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
-            text_encoder_2_opts = device.sess_options(cache=False)
-            text_encoder_2_opts.add_external_initializers(
-                list(text_encoder_2_names), list(text_encoder_2_values)
-            )
-            text_encoder_2_session = InferenceSession(
-                text_encoder_2.SerializeToString(),
-                providers=[device.ort_provider("text-encoder")],
-                sess_options=text_encoder_2_opts,
-            )
-            text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
-            components["text_encoder_2_session"] = text_encoder_2_session
-        else:
-            components["text_encoder"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    text_encoder.SerializeToString(),
-                    provider=device.ort_provider("text-encoder"),
-                    sess_options=text_encoder_opts,
-                )
-            )
+
+    # prepare external data for sessions
+    (text_encoder, text_encoder_data) = buffer_external_data_tensors(text_encoder)
+    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
+    text_encoder_opts = device.sess_options(cache=False)
+    text_encoder_opts.add_external_initializers(
+        list(text_encoder_names), list(text_encoder_values)
+    )
+
+    if params.is_xl():
+        # encoder 2 only exists in XL
+        (text_encoder_2, text_encoder_2_data) = buffer_external_data_tensors(
+            text_encoder_2
+        )
+        text_encoder_2_names, text_encoder_2_values = zip(*text_encoder_2_data)
+        text_encoder_2_opts = device.sess_options(cache=False)
+        text_encoder_2_opts.add_external_initializers(
+            list(text_encoder_2_names), list(text_encoder_2_values)
+        )
+
+        # session for te1
+        text_encoder_session = InferenceSession(
+            text_encoder,
+            providers=[device.ort_provider("text-encoder")],
+            sess_options=text_encoder_opts,
+        )
+        text_encoder_session._model_path = path.join(model, "text_encoder")
+        components["text_encoder_session"] = text_encoder_session
+
+        # session for te2
+        text_encoder_2_session = InferenceSession(
+            text_encoder_2,
+            providers=[device.ort_provider("text-encoder")],
+            sess_options=text_encoder_2_opts,
+        )
+        text_encoder_2_session._model_path = path.join(model, "text_encoder_2")
+        components["text_encoder_2_session"] = text_encoder_2_session
+    else:
+        # session for te
+        components["text_encoder"] = OnnxRuntimeModel(
+            OnnxRuntimeModel.load_model(
+                text_encoder,
+                provider=device.ort_provider("text-encoder"),
+                sess_options=text_encoder_opts,
+            )
+        )

     return components


 def load_unet(server, device, model, loras, unet_type, params):
     components = {}
-    unet = path.join(model, unet_type, ONNX_MODEL)
+    unet = load_model(path.join(model, unet_type, ONNX_MODEL))

     # LoRA blending
     if loras is not None and len(loras) > 0:
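Note the split the new code keeps: XL pipelines receive bare InferenceSession objects under *_session keys, with a private _model_path planted on them, while non-XL pipelines get diffusers' OnnxRuntimeModel wrapper. A reduced sketch of that split (the downstream consumer of these keys is an assumption, not shown in this diff):

    from os import path

    from diffusers import OnnxRuntimeModel
    from onnxruntime import InferenceSession

    def wrap_text_encoder(model_dir: str, session: InferenceSession, xl: bool) -> dict:
        if xl:
            # bare session; _model_path is a hint for code that resolves
            # files relative to the model directory, as in the diff above
            session._model_path = path.join(model_dir, "text_encoder")
            return {"text_encoder_session": session}

        # non-XL: wrap the session in diffusers' OnnxRuntimeModel
        return {"text_encoder": OnnxRuntimeModel(model=session)}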
@@ -446,37 +431,26 @@ def load_unet(server, device, model, loras, unet_type, params):
             "unet",
             xl=params.is_xl(),
         )
-        (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
-        unet_names, unet_values = zip(*unet_data)
-        unet_opts = device.sess_options(cache=False)
-        unet_opts.add_external_initializers(list(unet_names), list(unet_values))

-        if params.is_xl():
-            unet_session = InferenceSession(
-                unet_model.SerializeToString(),
-                providers=[device.ort_provider("unet")],
-                sess_options=unet_opts,
-            )
-            unet_session._model_path = path.join(model, "unet")
-            components["unet_session"] = unet_session
-        else:
-            components["unet"] = OnnxRuntimeModel(
-                OnnxRuntimeModel.load_model(
-                    unet_model.SerializeToString(),
-                    provider=device.ort_provider("unet"),
-                    sess_options=unet_opts,
-                )
-            )
+    (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
+    unet_names, unet_values = zip(*unet_data)
+    unet_opts = device.sess_options(cache=False)
+    unet_opts.add_external_initializers(list(unet_names), list(unet_values))

-    # make sure a UNet has been loaded
-    if not params.is_xl() and "unet" not in components:
-        unet = path.join(model, unet_type, ONNX_MODEL)
-        logger.debug("loading UNet (%s) from %s", unet_type, unet)
-
+    if params.is_xl():
+        unet_session = InferenceSession(
+            unet_model.SerializeToString(),
+            providers=[device.ort_provider("unet")],
+            sess_options=unet_opts,
+        )
+        unet_session._model_path = path.join(model, "unet")
+        components["unet_session"] = unet_session
+    else:
         components["unet"] = OnnxRuntimeModel(
             OnnxRuntimeModel.load_model(
-                unet,
+                unet_model.SerializeToString(),
                 provider=device.ort_provider("unet"),
-                sess_options=device.sess_options(),
+                sess_options=unet_opts,
             )
         )
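Taken together, load_unet now follows one route whether or not LoRAs were blended: load the protobuf, optionally blend, buffer external data once, and create exactly one session. A condensed sketch of that flow, reusing the helpers as they appear in the diff (their implementations are assumed, and the verbatim function differs):

    def load_unet_flow(server, device, model, loras, unet_type, params):
        # condensed shape of the deduplicated logic in this commit
        components = {}
        unet_model = load_model(path.join(model, unet_type, ONNX_MODEL))

        if loras:
            lora_names, lora_weights = zip(*loras)
            lora_models = [
                path.join(server.model_path, "lora", name) for name in lora_names
            ]
            unet_model = blend_loras(
                server,
                unet_model,
                list(zip(lora_models, lora_weights)),
                "unet",
                xl=params.is_xl(),
            )

        # one buffering step and one session path for both branches
        (unet_model, unet_data) = buffer_external_data_tensors(unet_model)
        unet_names, unet_values = zip(*unet_data)
        unet_opts = device.sess_options(cache=False)
        unet_opts.add_external_initializers(list(unet_names), list(unet_values))

        if params.is_xl():
            unet_session = InferenceSession(
                unet_model.SerializeToString(),
                providers=[device.ort_provider("unet")],
                sess_options=unet_opts,
            )
            unet_session._model_path = path.join(model, "unet")
            components["unet_session"] = unet_session
        else:
            components["unet"] = OnnxRuntimeModel(
                OnnxRuntimeModel.load_model(
                    unet_model.SerializeToString(),
                    provider=device.ort_provider("unet"),
                    sess_options=unet_opts,
                )
            )

        return components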