1
0
Fork 0

fix(api): embed Inversion concepts using their name

This commit is contained in:
Sean Sube 2023-03-19 19:01:22 -05:00
parent 19d4d554c3
commit fe498b16f0
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed files with 9 additions and 3 deletions

View File

@ -44,10 +44,10 @@ def blend_textual_inversions(
if inversion_format == "concept":
# TODO: this should be done in fetch, maybe
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") # not strictly needed
with open(token_file, "r") as f:
token = base_token or f.read()
token = f.read()
loaded_embeds = load_tensor(embeds_file, map_location=device)
if loaded_embeds is None:
@ -59,7 +59,13 @@ def blend_textual_inversions(
layer = loaded_embeds[trained_token].numpy().astype(dtype)
layer *= weight
if trained_token in embeds:
if base_token in embeds:
embeds[base_token] += layer
else:
embeds[base_token] = layer
if token in embeds:
embeds[token] += layer
else:
embeds[token] = layer