1
0
Fork 0

fix(api): build sum tokens for TIs using emb_params key

This commit is contained in:
Sean Sube 2023-04-22 12:07:52 -05:00
parent 6c00b2d87d
commit 5f5418132b
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 34 additions and 1 deletions

View File

@ -48,8 +48,11 @@ def blend_textual_inversions(
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"): if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
logger.debug("detected Textual Inversion concept: %s", keys) logger.debug("detected Textual Inversion concept: %s", keys)
inversion_format = "concept" inversion_format = "concept"
elif "emb_params" in keys:
logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
inversion_format = "parameters"
elif "string_to_token" in keys and "string_to_param" in keys: elif "string_to_token" in keys and "string_to_param" in keys:
logger.debug("detected Textual Inversion embeddings: %s", keys) logger.debug("detected Textual Inversion token embeddings: %s", keys)
inversion_format = "embeddings" inversion_format = "embeddings"
else: else:
logger.error( logger.error(
@ -73,6 +76,36 @@ def blend_textual_inversions(
embeds[token] += layer embeds[token] += layer
else: else:
embeds[token] = layer embeds[token] = layer
elif inversion_format == "parameters":
emb_params = loaded_embeds["emb_params"]
num_tokens = emb_params.shape[0]
logger.debug("generating %s layer tokens for %s", num_tokens, name)
sum_layer = np.zeros(emb_params[0, :].shape)
for i in range(num_tokens):
token = f"{base_token}-{i}"
layer = emb_params[i, :].numpy().astype(dtype)
layer *= weight
sum_layer += layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer
# add base and sum tokens to embeds
if base_token in embeds:
embeds[base_token] += sum_layer
else:
embeds[base_token] = sum_layer
sum_token = f"{base_token}-all"
if sum_token in embeds:
embeds[sum_token] += sum_layer
else:
embeds[sum_token] = sum_layer
elif inversion_format == "embeddings": elif inversion_format == "embeddings":
string_to_token = loaded_embeds["string_to_token"] string_to_token = loaded_embeds["string_to_token"]
string_to_param = loaded_embeds["string_to_param"] string_to_param = loaded_embeds["string_to_param"]