more lint
This commit is contained in:
parent
e104c81e19
commit
2cb0a6be3c
|
@ -24,13 +24,13 @@ def blend_textual_inversions(
|
||||||
dtype = np.float
|
dtype = np.float
|
||||||
embeds = {}
|
embeds = {}
|
||||||
|
|
||||||
for name, weight, base_token, format in inversions:
|
for name, weight, base_token, inversion_format in inversions:
|
||||||
if base_token is None:
|
if base_token is None:
|
||||||
base_token = name
|
base_token = name
|
||||||
|
|
||||||
if format is None:
|
if inversion_format is None:
|
||||||
# TODO: detect concept format
|
# TODO: detect concept format
|
||||||
format = "embeddings"
|
inversion_format = "embeddings"
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"blending Textual Inversion %s with weight of %s for token %s",
|
"blending Textual Inversion %s with weight of %s for token %s",
|
||||||
|
@ -39,7 +39,7 @@ def blend_textual_inversions(
|
||||||
base_token,
|
base_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if format == "concept":
|
if inversion_format == "concept":
|
||||||
# TODO: this should be done in fetch, maybe
|
# TODO: this should be done in fetch, maybe
|
||||||
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
|
||||||
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
|
token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
|
||||||
|
@ -58,7 +58,7 @@ def blend_textual_inversions(
|
||||||
embeds[token] += layer
|
embeds[token] += layer
|
||||||
else:
|
else:
|
||||||
embeds[token] = layer
|
embeds[token] = layer
|
||||||
elif format == "embeddings":
|
elif inversion_format == "embeddings":
|
||||||
loaded_embeds = torch.load(name)
|
loaded_embeds = torch.load(name)
|
||||||
|
|
||||||
string_to_token = loaded_embeds["string_to_token"]
|
string_to_token = loaded_embeds["string_to_token"]
|
||||||
|
@ -91,7 +91,7 @@ def blend_textual_inversions(
|
||||||
else:
|
else:
|
||||||
embeds[sum_token] = sum_layer
|
embeds[sum_token] = sum_layer
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"unknown Textual Inversion format: {format}")
|
raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
|
||||||
|
|
||||||
# add the tokens to the tokenizer
|
# add the tokens to the tokenizer
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
Loading…
Reference in New Issue