Add a warning log when Conv weight shapes do not match during LoRA blending (also renames base_weights to onnx_weights)
This commit is contained in:
parent
4db1737f98
commit
512f41135d
|
@ -312,31 +312,36 @@ def blend_loras(
|
||||||
logger.trace("found weight initializer: %s", weight_node.name)
|
logger.trace("found weight initializer: %s", weight_node.name)
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(weight_node)
|
onnx_weights = numpy_helper.to_array(weight_node)
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"found blended weights for conv: %s, %s",
|
"found blended weights for conv: %s, %s",
|
||||||
|
onnx_weights.shape,
|
||||||
weights.shape,
|
weights.shape,
|
||||||
base_weights.shape,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if base_weights.shape[-2:] == (1, 1):
|
if onnx_weights.shape[-2:] == (1, 1):
|
||||||
if weights.shape[-2:] == (1, 1):
|
if weights.shape[-2:] == (1, 1):
|
||||||
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||||
else:
|
else:
|
||||||
blended = base_weights.squeeze((3, 2)) + weights
|
blended = onnx_weights.squeeze((3, 2)) + weights
|
||||||
|
|
||||||
blended = np.expand_dims(blended, (2, 3))
|
blended = np.expand_dims(blended, (2, 3))
|
||||||
else:
|
else:
|
||||||
if base_weights.shape != weights.shape:
|
if onnx_weights.shape != weights.shape:
|
||||||
blended = base_weights + weights.reshape(base_weights.shape)
|
logger.warning(
|
||||||
|
"reshaping weights for mismatched Conv node: %s, %s",
|
||||||
|
onnx_weights.shape,
|
||||||
|
weights.shape,
|
||||||
|
)
|
||||||
|
blended = onnx_weights + weights.reshape(onnx_weights.shape)
|
||||||
else:
|
else:
|
||||||
blended = base_weights + weights
|
blended = onnx_weights + weights
|
||||||
|
|
||||||
logger.trace("blended weight shape: %s", blended.shape)
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(
|
updated_node = numpy_helper.from_array(
|
||||||
blended.astype(base_weights.dtype), weight_node.name
|
blended.astype(onnx_weights.dtype), weight_node.name
|
||||||
)
|
)
|
||||||
del base_model.graph.initializer[weight_idx]
|
del base_model.graph.initializer[weight_idx]
|
||||||
base_model.graph.initializer.insert(weight_idx, updated_node)
|
base_model.graph.initializer.insert(weight_idx, updated_node)
|
||||||
|
@ -355,19 +360,19 @@ def blend_loras(
|
||||||
logger.trace("found matmul initializer: %s", matmul_node.name)
|
logger.trace("found matmul initializer: %s", matmul_node.name)
|
||||||
|
|
||||||
# blending
|
# blending
|
||||||
base_weights = numpy_helper.to_array(matmul_node)
|
onnx_weights = numpy_helper.to_array(matmul_node)
|
||||||
logger.trace(
|
logger.trace(
|
||||||
"found blended weights for matmul: %s, %s",
|
"found blended weights for matmul: %s, %s",
|
||||||
weights.shape,
|
weights.shape,
|
||||||
base_weights.shape,
|
onnx_weights.shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
blended = base_weights + weights.transpose()
|
blended = onnx_weights + weights.transpose()
|
||||||
logger.trace("blended weight shape: %s", blended.shape)
|
logger.trace("blended weight shape: %s", blended.shape)
|
||||||
|
|
||||||
# replace the original initializer
|
# replace the original initializer
|
||||||
updated_node = numpy_helper.from_array(
|
updated_node = numpy_helper.from_array(
|
||||||
blended.astype(base_weights.dtype), matmul_node.name
|
blended.astype(onnx_weights.dtype), matmul_node.name
|
||||||
)
|
)
|
||||||
del base_model.graph.initializer[matmul_idx]
|
del base_model.graph.initializer[matmul_idx]
|
||||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||||
|
|
Loading…
Reference in New Issue