fix(api): embed Inversion concepts using their name
This commit is contained in:
parent
19d4d554c3
commit
fe498b16f0
|
@ -44,10 +44,10 @@ def blend_textual_inversions(
|
||||||
if inversion_format == "concept":
|
if inversion_format == "concept":
|
||||||
# TODO: this should be done in fetch, maybe
|
# TODO: this should be done in fetch, maybe
|
||||||
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
||||||
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
|
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt") # not strictly needed
|
||||||
|
|
||||||
with open(token_file, "r") as f:
|
with open(token_file, "r") as f:
|
||||||
token = base_token or f.read()
|
token = f.read()
|
||||||
|
|
||||||
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
loaded_embeds = load_tensor(embeds_file, map_location=device)
|
||||||
if loaded_embeds is None:
|
if loaded_embeds is None:
|
||||||
|
@ -59,7 +59,13 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
layer = loaded_embeds[trained_token].numpy().astype(dtype)
|
layer = loaded_embeds[trained_token].numpy().astype(dtype)
|
||||||
layer *= weight
|
layer *= weight
|
||||||
if trained_token in embeds:
|
|
||||||
|
if base_token in embeds:
|
||||||
|
embeds[base_token] += layer
|
||||||
|
else:
|
||||||
|
embeds[base_token] = layer
|
||||||
|
|
||||||
|
if token in embeds:
|
||||||
embeds[token] += layer
|
embeds[token] += layer
|
||||||
else:
|
else:
|
||||||
embeds[token] = layer
|
embeds[token] = layer
|
||||||
|
|
Loading…
Reference in New Issue