[MLIR] Fix out-of-bounds crash in matchReduction for ops with fewer yield operands (#183555)

`mlir::matchReduction()` accessed `terminatorOp->getOperand(redPos)`
without checking that `redPos` is within the terminator's operand count.
This caused an assertion failure when the region's block-argument count
exceeds the terminator's yield count, e.g. for `linalg.pooling_nhwc_sum`
whose kernel region has three block args but yields only one value.

Add a bounds check before the operand access so the function returns
nullptr gracefully instead of crashing.

Fixes #131437
This commit is contained in:
Mehdi Amini 2026-02-26 17:50:53 +01:00 committed by GitHub
parent b9ad15feb5
commit 2a6a62ac33
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 1 deletions

View File

@ -329,7 +329,8 @@ Value mlir::matchReduction(ArrayRef<BlockArgument> iterCarriedArgs,
// Check that the yielded value is in the same position as in
// `iterCarriedArgs`.
Operation *terminatorOp = combinerOp;
if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
if (redPos >= terminatorOp->getNumOperands() ||
terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0])
return nullptr;
return reducedVal;

View File

@ -112,3 +112,19 @@ func.func @affine_output_dep(%in: memref<512xf32>) {
return
}
// -----
// Verify that matchReduction does not crash when the terminator has fewer
// operands than the number of iteration-carried block arguments (issue #131437).
// expected-remark@below {{Testing function}}
func.func @pooling_nhwc_sum_no_crash(%arg0: tensor<1x1x1x1xf32>,
%arg1: tensor<1x3x3x1xf32>) {
%0 = tensor.empty() : tensor<1x1xf32>
// expected-remark@below {{Reduction NOT found in output #0!}}
// expected-remark@below {{Reduction NOT found in output #1!}}
%1 = linalg.pooling_nhwc_sum
{dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
ins(%arg1, %0 : tensor<1x3x3x1xf32>, tensor<1x1xf32>)
outs(%arg0 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
return
}