From 2a6a62ac33287eb96eb8dd9bae7ebc765b68d185 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 26 Feb 2026 17:50:53 +0100 Subject: [PATCH] [MLIR] Fix out-of-bounds crash in matchReduction for ops with fewer yield operands (#183555) `mlir::matchReduction()` accessed `terminatorOp->getOperand(redPos)` without checking that `redPos` is within the terminator's operand count. This caused an assertion failure when the region's block-argument count exceeds the terminator's yield count, e.g. for `linalg.pooling_nhwc_sum` whose kernel region has three block args but yields only one value. Add a bounds check before the operand access so the function returns nullptr gracefully instead of crashing. Fixes #131437 --- mlir/lib/Analysis/SliceAnalysis.cpp | 3 ++- mlir/test/Analysis/test-match-reduction.mlir | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 12dff19ed31d..f388cd8041d6 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -329,7 +329,8 @@ Value mlir::matchReduction(ArrayRef iterCarriedArgs, // Check that the yielded value is in the same position as in // `iterCarriedArgs`. Operation *terminatorOp = combinerOp; - if (terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) + if (redPos >= terminatorOp->getNumOperands() || + terminatorOp->getOperand(redPos) != combinerOps.back()->getResults()[0]) return nullptr; return reducedVal; diff --git a/mlir/test/Analysis/test-match-reduction.mlir b/mlir/test/Analysis/test-match-reduction.mlir index b5902db77e89..7faddf902521 100644 --- a/mlir/test/Analysis/test-match-reduction.mlir +++ b/mlir/test/Analysis/test-match-reduction.mlir @@ -112,3 +112,19 @@ func.func @affine_output_dep(%in: memref<512xf32>) { return } +// ----- + +// Verify that matchReduction does not crash when the terminator has fewer +// operands than the number of iteration-carried block arguments (issue #131437). +// expected-remark@below {{Testing function}} +func.func @pooling_nhwc_sum_no_crash(%arg0: tensor<1x1x1x1xf32>, + %arg1: tensor<1x3x3x1xf32>) { + %0 = tensor.empty() : tensor<1x1xf32> + // expected-remark@below {{Reduction NOT found in output #0!}} + // expected-remark@below {{Reduction NOT found in output #1!}} + %1 = linalg.pooling_nhwc_sum + {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} + ins(%arg1, %0 : tensor<1x3x3x1xf32>, tensor<1x1xf32>) + outs(%arg0 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + return +}