From b3ca423a78ec436936834170cd8b96a25d3f8b7e Mon Sep 17 00:00:00 2001 From: Nishant Patel Date: Wed, 1 Apr 2026 14:59:08 -0700 Subject: [PATCH] [MLIR][Vector] Enhance vector.multi_reduction unrolling to handle scalar result (#188633) Previously, UnrollMultiReductionPattern bailed out when all the dimensions were reduced to a scalar. This PR adds support for this case by tiling the source vector and chaining partial reductions through the accumulator operand. --- .../Vector/Transforms/VectorUnroll.cpp | 31 +++++++++++++++---- .../Dialect/Vector/vector-unroll-options.mlir | 14 +++++---- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 37a691c4fce7..ec08f01d2a4b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -381,21 +381,40 @@ struct UnrollMultiReductionPattern LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { - auto resultType = reductionOp->getResult(0).getType(); - if (resultType.isIntOrFloat()) { - return rewriter.notifyMatchFailure(reductionOp, - "Unrolling scalars is not supported"); - } std::optional<SmallVector<int64_t>> targetShape = getTargetShape(options, reductionOp); if (!targetShape) return failure(); SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll(); + Location loc = reductionOp.getLoc(); + auto resultType = reductionOp->getResult(0).getType(); + + // Handle scalar result case: all dimensions are reduced. + // Each source tile is reduced to a scalar, and partial results are + // chained through the accumulator operand. 
+ if (resultType.isIntOrFloat()) { + Value accumulator = reductionOp.getAcc(); + for (SmallVector<int64_t> offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { + SmallVector<int64_t> operandStrides(offsets.size(), 1); + Value slicedOperand = + rewriter.createOrFold<vector::ExtractStridedSliceOp>( + loc, reductionOp.getSource(), offsets, *targetShape, + operandStrides); + Operation *newOp = cloneOpWithOperandsAndTypes( + rewriter, loc, reductionOp, {slicedOperand, accumulator}, + resultType); + accumulator = newOp->getResult(0); + } + rewriter.replaceOp(reductionOp, accumulator); + return success(); + } + + // Vector result case. llvm::MapVector< SmallVector<int64_t>, Value, llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>> accCache; - Location loc = reductionOp.getLoc(); // Stride of the ratios, this gives us the offsets of sliceCount in a basis // of multiples of the targetShape. diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 14bc81a06c09..036d09053552 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -245,15 +245,17 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> // CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> // CHECK: return %[[V2]] : vector<4xf32> -// This is a negative test case to ensure that further unrolling is not performed. Since the vector.multi_reduction -// operation has already been unrolled, attempting additional unrolling should not be allowed. 
-func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 { +func.func @vector_multi_reduction_scalar(%v: vector<4x2xf32>, %acc: f32) -> f32 { %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32 return %0 : f32 } -// CHECK-LABEL: func @negative_vector_multi_reduction -// CHECK-NEXT: %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32 -// CHECK-NEXT: return %[[R0]] : f32 +// CHECK-LABEL: func @vector_multi_reduction_scalar +// CHECK-SAME: %[[V:.*]]: vector<4x2xf32>, %[[ACC:.*]]: f32 +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> +// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %[[S0]], %[[ACC]] [0, 1] : vector<2x2xf32> to f32 +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> +// CHECK: %[[R1:.*]] = vector.multi_reduction <add>, %[[S1]], %[[R0]] [0, 1] : vector<2x2xf32> to f32 +// CHECK: return %[[R1]] : f32 func.func @vector_reduction(%v : vector<8xf32>) -> f32 { %0 = vector.reduction <add>, %v : vector<8xf32> into f32