1
0
Fork 0

add logging when conv shapes do not match

This commit is contained in:
Sean Sube 2023-04-10 08:14:41 -05:00
parent 4db1737f98
commit 512f41135d
Signed by: ssube
GPG Key ID: 3EED7B957D362AF1
1 changed file with 18 additions and 13 deletions

View File

@@ -312,31 +312,36 @@ def blend_loras(
logger.trace("found weight initializer: %s", weight_node.name) logger.trace("found weight initializer: %s", weight_node.name)
# blending # blending
base_weights = numpy_helper.to_array(weight_node) onnx_weights = numpy_helper.to_array(weight_node)
logger.trace( logger.trace(
"found blended weights for conv: %s, %s", "found blended weights for conv: %s, %s",
onnx_weights.shape,
weights.shape, weights.shape,
base_weights.shape,
) )
if base_weights.shape[-2:] == (1, 1): if onnx_weights.shape[-2:] == (1, 1):
if weights.shape[-2:] == (1, 1): if weights.shape[-2:] == (1, 1):
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
else: else:
blended = base_weights.squeeze((3, 2)) + weights blended = onnx_weights.squeeze((3, 2)) + weights
blended = np.expand_dims(blended, (2, 3)) blended = np.expand_dims(blended, (2, 3))
else: else:
if base_weights.shape != weights.shape: if onnx_weights.shape != weights.shape:
blended = base_weights + weights.reshape(base_weights.shape) logger.warning(
"reshaping weights for mismatched Conv node: %s, %s",
onnx_weights.shape,
weights.shape,
)
blended = onnx_weights + weights.reshape(onnx_weights.shape)
else: else:
blended = base_weights + weights blended = onnx_weights + weights
logger.trace("blended weight shape: %s", blended.shape) logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer # replace the original initializer
updated_node = numpy_helper.from_array( updated_node = numpy_helper.from_array(
blended.astype(base_weights.dtype), weight_node.name blended.astype(onnx_weights.dtype), weight_node.name
) )
del base_model.graph.initializer[weight_idx] del base_model.graph.initializer[weight_idx]
base_model.graph.initializer.insert(weight_idx, updated_node) base_model.graph.initializer.insert(weight_idx, updated_node)
@@ -355,19 +360,19 @@ def blend_loras(
logger.trace("found matmul initializer: %s", matmul_node.name) logger.trace("found matmul initializer: %s", matmul_node.name)
# blending # blending
base_weights = numpy_helper.to_array(matmul_node) onnx_weights = numpy_helper.to_array(matmul_node)
logger.trace( logger.trace(
"found blended weights for matmul: %s, %s", "found blended weights for matmul: %s, %s",
weights.shape, weights.shape,
base_weights.shape, onnx_weights.shape,
) )
blended = base_weights + weights.transpose() blended = onnx_weights + weights.transpose()
logger.trace("blended weight shape: %s", blended.shape) logger.trace("blended weight shape: %s", blended.shape)
# replace the original initializer # replace the original initializer
updated_node = numpy_helper.from_array( updated_node = numpy_helper.from_array(
blended.astype(base_weights.dtype), matmul_node.name blended.astype(onnx_weights.dtype), matmul_node.name
) )
del base_model.graph.initializer[matmul_idx] del base_model.graph.initializer[matmul_idx]
base_model.graph.initializer.insert(matmul_idx, updated_node) base_model.graph.initializer.insert(matmul_idx, updated_node)