fix(api): build sum tokens for TIs using emb_params key
This commit is contained in:
parent
6c00b2d87d
commit
5f5418132b
|
@ -48,8 +48,11 @@ def blend_textual_inversions(
|
|||
if len(keys) == 1 and keys[0].startswith("<") and keys[0].endswith(">"):
|
||||
logger.debug("detected Textual Inversion concept: %s", keys)
|
||||
inversion_format = "concept"
|
||||
elif "emb_params" in keys:
|
||||
logger.debug("detected Textual Inversion parameter embeddings: %s", keys)
|
||||
inversion_format = "parameters"
|
||||
elif "string_to_token" in keys and "string_to_param" in keys:
|
||||
logger.debug("detected Textual Inversion embeddings: %s", keys)
|
||||
logger.debug("detected Textual Inversion token embeddings: %s", keys)
|
||||
inversion_format = "embeddings"
|
||||
else:
|
||||
logger.error(
|
||||
|
@ -73,6 +76,36 @@ def blend_textual_inversions(
|
|||
embeds[token] += layer
|
||||
else:
|
||||
embeds[token] = layer
|
||||
elif inversion_format == "parameters":
|
||||
emb_params = loaded_embeds["emb_params"]
|
||||
|
||||
num_tokens = emb_params.shape[0]
|
||||
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
||||
|
||||
sum_layer = np.zeros(emb_params[0, :].shape)
|
||||
|
||||
for i in range(num_tokens):
|
||||
token = f"{base_token}-{i}"
|
||||
layer = emb_params[i, :].numpy().astype(dtype)
|
||||
layer *= weight
|
||||
|
||||
sum_layer += layer
|
||||
if token in embeds:
|
||||
embeds[token] += layer
|
||||
else:
|
||||
embeds[token] = layer
|
||||
|
||||
# add base and sum tokens to embeds
|
||||
if base_token in embeds:
|
||||
embeds[base_token] += sum_layer
|
||||
else:
|
||||
embeds[base_token] = sum_layer
|
||||
|
||||
sum_token = f"{base_token}-all"
|
||||
if sum_token in embeds:
|
||||
embeds[sum_token] += sum_layer
|
||||
else:
|
||||
embeds[sum_token] = sum_layer
|
||||
elif inversion_format == "embeddings":
|
||||
string_to_token = loaded_embeds["string_to_token"]
|
||||
string_to_param = loaded_embeds["string_to_param"]
|
||||
|
|
Loading…
Reference in New Issue