Add a warning log when Conv weight shapes do not match and must be reshaped during LoRA blending; rename base_weights to onnx_weights
This commit is contained in:
parent
4db1737f98
commit
512f41135d
|
@ -312,31 +312,36 @@ def blend_loras(
|
|||
logger.trace("found weight initializer: %s", weight_node.name)
|
||||
|
||||
# blending
|
||||
base_weights = numpy_helper.to_array(weight_node)
|
||||
onnx_weights = numpy_helper.to_array(weight_node)
|
||||
logger.trace(
|
||||
"found blended weights for conv: %s, %s",
|
||||
onnx_weights.shape,
|
||||
weights.shape,
|
||||
base_weights.shape,
|
||||
)
|
||||
|
||||
if base_weights.shape[-2:] == (1, 1):
|
||||
if onnx_weights.shape[-2:] == (1, 1):
|
||||
if weights.shape[-2:] == (1, 1):
|
||||
blended = base_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||
blended = onnx_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
|
||||
else:
|
||||
blended = base_weights.squeeze((3, 2)) + weights
|
||||
blended = onnx_weights.squeeze((3, 2)) + weights
|
||||
|
||||
blended = np.expand_dims(blended, (2, 3))
|
||||
else:
|
||||
if base_weights.shape != weights.shape:
|
||||
blended = base_weights + weights.reshape(base_weights.shape)
|
||||
if onnx_weights.shape != weights.shape:
|
||||
logger.warning(
|
||||
"reshaping weights for mismatched Conv node: %s, %s",
|
||||
onnx_weights.shape,
|
||||
weights.shape,
|
||||
)
|
||||
blended = onnx_weights + weights.reshape(onnx_weights.shape)
|
||||
else:
|
||||
blended = base_weights + weights
|
||||
blended = onnx_weights + weights
|
||||
|
||||
logger.trace("blended weight shape: %s", blended.shape)
|
||||
|
||||
# replace the original initializer
|
||||
updated_node = numpy_helper.from_array(
|
||||
blended.astype(base_weights.dtype), weight_node.name
|
||||
blended.astype(onnx_weights.dtype), weight_node.name
|
||||
)
|
||||
del base_model.graph.initializer[weight_idx]
|
||||
base_model.graph.initializer.insert(weight_idx, updated_node)
|
||||
|
@ -355,19 +360,19 @@ def blend_loras(
|
|||
logger.trace("found matmul initializer: %s", matmul_node.name)
|
||||
|
||||
# blending
|
||||
base_weights = numpy_helper.to_array(matmul_node)
|
||||
onnx_weights = numpy_helper.to_array(matmul_node)
|
||||
logger.trace(
|
||||
"found blended weights for matmul: %s, %s",
|
||||
weights.shape,
|
||||
base_weights.shape,
|
||||
onnx_weights.shape,
|
||||
)
|
||||
|
||||
blended = base_weights + weights.transpose()
|
||||
blended = onnx_weights + weights.transpose()
|
||||
logger.trace("blended weight shape: %s", blended.shape)
|
||||
|
||||
# replace the original initializer
|
||||
updated_node = numpy_helper.from_array(
|
||||
blended.astype(base_weights.dtype), matmul_node.name
|
||||
blended.astype(onnx_weights.dtype), matmul_node.name
|
||||
)
|
||||
del base_model.graph.initializer[matmul_idx]
|
||||
base_model.graph.initializer.insert(matmul_idx, updated_node)
|
||||
|
|
Loading…
Reference in New Issue