fix(api): move model conversion messages to trace level
This commit is contained in:
parent
c397c1e42a
commit
9f0a6f134e
|
@ -96,7 +96,7 @@ def blend_loras(
|
||||||
try:
|
try:
|
||||||
if len(up_weight.size()) == 2:
|
if len(up_weight.size()) == 2:
|
||||||
# blend for nn.Linear
|
# blend for nn.Linear
|
||||||
logger.debug(
|
logger.trace(
|
||||||
"blending weights for Linear node: %s, %s, %s",
|
"blending weights for Linear node: %s, %s, %s",
|
||||||
down_weight.shape,
|
down_weight.shape,
|
||||||
up_weight.shape,
|
up_weight.shape,
|
||||||
|
@ -106,7 +106,7 @@ def blend_loras(
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (1, 1):
|
||||||
# blend for nn.Conv2d 1x1
|
# blend for nn.Conv2d 1x1
|
||||||
logger.debug(
|
logger.trace(
|
||||||
"blending weights for Conv 1x1 node: %s, %s, %s",
|
"blending weights for Conv 1x1 node: %s, %s, %s",
|
||||||
down_weight.shape,
|
down_weight.shape,
|
||||||
up_weight.shape,
|
up_weight.shape,
|
||||||
|
@ -123,7 +123,7 @@ def blend_loras(
|
||||||
np_weights = weights.numpy() * (alpha / dim)
|
np_weights = weights.numpy() * (alpha / dim)
|
||||||
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
|
elif len(up_weight.size()) == 4 and up_weight.shape[-2:] == (3, 3):
|
||||||
# blend for nn.Conv2d 3x3
|
# blend for nn.Conv2d 3x3
|
||||||
logger.debug(
|
logger.trace(
|
||||||
"blending weights for Conv 3x3 node: %s, %s, %s",
|
"blending weights for Conv 3x3 node: %s, %s, %s",
|
||||||
down_weight.shape,
|
down_weight.shape,
|
||||||
up_weight.shape,
|
up_weight.shape,
|
||||||
|
@ -150,7 +150,7 @@ def blend_loras(
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("error blending weights for key %s", base_key)
|
logger.exception("error blending weights for key %s", base_key)
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
"updating %s of %s initializers: %s",
|
"updating %s of %s initializers: %s",
|
||||||
len(blended.keys()),
|
len(blended.keys()),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
|
@ -169,7 +169,7 @@ def blend_loras(
|
||||||
conv_key = base_key + "_Conv"
|
conv_key = base_key + "_Conv"
|
||||||
matmul_key = base_key + "_MatMul"
|
matmul_key = base_key + "_MatMul"
|
||||||
|
|
||||||
logger.debug(
|
logger.trace(
|
||||||
"key %s has conv: %s, matmul: %s",
|
"key %s has conv: %s, matmul: %s",
|
||||||
base_key,
|
base_key,
|
||||||
conv_key in fixed_node_names,
|
conv_key in fixed_node_names,
|
||||||
|
@ -179,20 +179,19 @@ def blend_loras(
|
||||||
if conv_key in fixed_node_names:
|
if conv_key in fixed_node_names:
|
||||||
conv_idx = fixed_node_names.index(conv_key)
|
conv_idx = fixed_node_names.index(conv_key)
|
||||||
conv_node = base_model.graph.node[conv_idx]
|
conv_node = base_model.graph.node[conv_idx]
|
||||||
logger.debug("found conv node: %s", conv_node.name)
|
logger.trace("found conv node %s using %s", conv_node.name, conv_node.input)
|
||||||
|
|
||||||
# find weight initializer
|
# find weight initializer
|
||||||
logger.debug("conv inputs: %s", conv_node.input)
|
|
||||||
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
weight_name = [n for n in conv_node.input if ".weight" in n][0]
|
||||||
weight_name = fix_initializer_name(weight_name)
|
weight_name = fix_initializer_name(weight_name)
|
||||||
|
|
||||||
weight_idx = fixed_initializer_names.index(weight_name)
|
weight_idx = fixed_initializer_names.index(weight_name)
|
||||||
weight_node = base_model.graph.initializer[weight_idx]
|
weight_node = base_model.graph.initializer[weight_idx]
|
||||||
logger.debug("found weight initializer: %s", weight_node.name)
|
logger.trace("found weight initializer: %s", weight_node.name)
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(weight_node)
|
base_weights = numpy_helper.to_array(weight_node)
|
||||||
logger.debug(
|
logger.trace(
|
||||||
"found blended weights for conv: %s, %s",
|
"found blended weights for conv: %s, %s",
|
||||||
weights.shape,
|
weights.shape,
|
||||||
base_weights.shape,
|
base_weights.shape,
|
||||||
|
@ -200,7 +199,7 @@ def blend_loras(
|
||||||
|
|
||||||
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||||
blended = np.expand_dims(blended, (2, 3))
|
blended = np.expand_dims(blended, (2, 3))
|
||||||
logger.debug("blended weight shape: %s", blended.shape)
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(blended, weight_node.name)
|
updated_node = numpy_helper.from_array(blended, weight_node.name)
|
||||||
|
@ -209,26 +208,27 @@ def blend_loras(
|
||||||
elif matmul_key in fixed_node_names:
|
elif matmul_key in fixed_node_names:
|
||||||
weight_idx = fixed_node_names.index(matmul_key)
|
weight_idx = fixed_node_names.index(matmul_key)
|
||||||
weight_node = base_model.graph.node[weight_idx]
|
weight_node = base_model.graph.node[weight_idx]
|
||||||
logger.debug("found matmul node: %s", weight_node.name)
|
logger.trace(
|
||||||
|
"found matmul node %s using %s", weight_node.name, weight_node.input
|
||||||
|
)
|
||||||
|
|
||||||
# find the MatMul initializer
|
# find the MatMul initializer
|
||||||
logger.debug("matmul inputs: %s", weight_node.input)
|
|
||||||
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
|
||||||
|
|
||||||
matmul_idx = fixed_initializer_names.index(matmul_name)
|
matmul_idx = fixed_initializer_names.index(matmul_name)
|
||||||
matmul_node = base_model.graph.initializer[matmul_idx]
|
matmul_node = base_model.graph.initializer[matmul_idx]
|
||||||
logger.debug("found matmul initializer: %s", matmul_node.name)
|
logger.trace("found matmul initializer: %s", matmul_node.name)
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(matmul_node)
|
base_weights = numpy_helper.to_array(matmul_node)
|
||||||
logger.debug(
|
logger.trace(
|
||||||
"found blended weights for matmul: %s, %s",
|
"found blended weights for matmul: %s, %s",
|
||||||
weights.shape,
|
weights.shape,
|
||||||
base_weights.shape,
|
base_weights.shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
blended = base_weights + weights.transpose()
|
blended = base_weights + weights.transpose()
|
||||||
logger.debug("blended weight shape: %s", blended.shape)
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(blended, matmul_node.name)
|
updated_node = numpy_helper.from_array(blended, matmul_node.name)
|
||||||
|
@ -237,7 +237,7 @@ def blend_loras(
|
||||||
else:
|
else:
|
||||||
logger.warning("could not find any nodes for %s", base_key)
|
logger.warning("could not find any nodes for %s", base_key)
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
"node counts: %s -> %s, %s -> %s",
|
"node counts: %s -> %s, %s -> %s",
|
||||||
len(fixed_initializer_names),
|
len(fixed_initializer_names),
|
||||||
len(base_model.graph.initializer),
|
len(base_model.graph.initializer),
|
||||||
|
|
|
@ -63,7 +63,7 @@ def blend_textual_inversions(
|
||||||
trained_embeds = string_to_param[trained_token]
|
trained_embeds = string_to_param[trained_token]
|
||||||
|
|
||||||
num_tokens = trained_embeds.shape[0]
|
num_tokens = trained_embeds.shape[0]
|
||||||
logger.debug("generating %s layer tokens", num_tokens)
|
logger.debug("generating %s layer tokens for %s", num_tokens, name)
|
||||||
|
|
||||||
for i in range(num_tokens):
|
for i in range(num_tokens):
|
||||||
token = f"{base_token or name}-{i}"
|
token = f"{base_token or name}-{i}"
|
||||||
|
@ -77,7 +77,7 @@ def blend_textual_inversions(
|
||||||
raise ValueError(f"unknown Textual Inversion format: {format}")
|
raise ValueError(f"unknown Textual Inversion format: {format}")
|
||||||
|
|
||||||
# add the tokens to the tokenizer
|
# add the tokens to the tokenizer
|
||||||
logger.info(
|
logger.debug(
|
||||||
"found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
|
"found embeddings for %s tokens: %s", len(embeds.keys()), embeds.keys()
|
||||||
)
|
)
|
||||||
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
num_added_tokens = tokenizer.add_tokens(list(embeds.keys()))
|
||||||
|
@ -86,7 +86,7 @@ def blend_textual_inversions(
|
||||||
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("added %s tokens", num_added_tokens)
|
logger.trace("added %s tokens", num_added_tokens)
|
||||||
|
|
||||||
# resize the token embeddings
|
# resize the token embeddings
|
||||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
# text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
@ -103,7 +103,7 @@ def blend_textual_inversions(
|
||||||
|
|
||||||
for token, weights in embeds.items():
|
for token, weights in embeds.items():
|
||||||
token_id = tokenizer.convert_tokens_to_ids(token)
|
token_id = tokenizer.convert_tokens_to_ids(token)
|
||||||
logger.debug("embedding %s weights for token %s", weights.shape, token)
|
logger.trace("embedding %s weights for token %s", weights.shape, token)
|
||||||
embedding_weights[token_id] = weights
|
embedding_weights[token_id] = weights
|
||||||
|
|
||||||
# replace embedding_node
|
# replace embedding_node
|
||||||
|
@ -115,7 +115,7 @@ def blend_textual_inversions(
|
||||||
new_initializer = numpy_helper.from_array(
|
new_initializer = numpy_helper.from_array(
|
||||||
embedding_weights.astype(np.float32), embedding_node.name
|
embedding_weights.astype(np.float32), embedding_node.name
|
||||||
)
|
)
|
||||||
logger.debug("new initializer data type: %s", new_initializer.data_type)
|
logger.trace("new initializer data type: %s", new_initializer.data_type)
|
||||||
del text_encoder.graph.initializer[i]
|
del text_encoder.graph.initializer[i]
|
||||||
text_encoder.graph.initializer.insert(i, new_initializer)
|
text_encoder.graph.initializer.insert(i, new_initializer)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue