
more lint

Sean Sube 2023-03-18 11:55:06 -05:00
parent e104c81e19
commit 2cb0a6be3c
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 6 additions and 6 deletions


@@ -24,13 +24,13 @@ def blend_textual_inversions(
     dtype = np.float
     embeds = {}
 
-    for name, weight, base_token, format in inversions:
+    for name, weight, base_token, inversion_format in inversions:
         if base_token is None:
             base_token = name
 
-        if format is None:
+        if inversion_format is None:
             # TODO: detect concept format
-            format = "embeddings"
+            inversion_format = "embeddings"
 
         logger.info(
             "blending Textual Inversion %s with weight of %s for token %s",
@@ -39,7 +39,7 @@ def blend_textual_inversions(
             base_token,
         )
 
-        if format == "concept":
+        if inversion_format == "concept":
            # TODO: this should be done in fetch, maybe
            embeds_file = hf_hub_download(repo_id=name, filename="learned_embeds.bin")
            token_file = hf_hub_download(repo_id=name, filename="token_identifier.txt")
@@ -58,7 +58,7 @@ def blend_textual_inversions(
                 embeds[token] += layer
             else:
                 embeds[token] = layer
-        elif format == "embeddings":
+        elif inversion_format == "embeddings":
             loaded_embeds = torch.load(name)
 
             string_to_token = loaded_embeds["string_to_token"]
@@ -91,7 +91,7 @@ def blend_textual_inversions(
             else:
                 embeds[sum_token] = sum_layer
         else:
-            raise ValueError(f"unknown Textual Inversion format: {format}")
+            raise ValueError(f"unknown Textual Inversion format: {inversion_format}")
 
     # add the tokens to the tokenizer
     logger.debug(
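
The rename matters because `format` is a Python built-in; binding it as a loop variable shadows the built-in for the rest of the function, which linters such as pylint (redefined-builtin, W0622) flag. A minimal sketch of the problem, independent of this repository's code:

    # `format` is a built-in function
    print(format(3.14159, ".2f"))  # prints "3.14"

    for format in ["concept", "embeddings"]:  # linters flag this: built-in shadowed
        pass

    # after the loop, `format` is the string "embeddings", not the built-in;
    # calling format(3.14159, ".2f") here raises TypeError: 'str' object is not callable

Renaming the loop variable to `inversion_format`, as this commit does, removes the shadowing without changing behavior.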