From 613a5c555ebffd7f32ad48de7253e8c25fe627a4 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Wed, 4 Mar 2026 11:13:11 -0500 Subject: [PATCH] [mlir][vector] Replace OneDimMultiReductionToTwoDim with OneDimMultiReductionToReduction (#184241) The `OneDimMultiReductionToTwoDim` pattern had some issues. For the input program: ```mlir func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 { %0 = vector.multi_reduction , %arg0, %acc [0] : vector<8xf32> to f32 return %0 : f32 } ``` * when lowering using the inner-parallel strategy, the compiler would essentially produce scalar code: ```mlir func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 { %0 = vector.shape_cast %arg0 : vector<8xf32> to vector<1x8xf32> %1 = vector.broadcast %arg1 : f32 to vector<1xf32> %2 = vector.transpose %0, [1, 0] : vector<1x8xf32> to vector<8x1xf32> %3 = vector.extract %2[0] : vector<1xf32> from vector<8x1xf32> %4 = arith.addf %3, %1 : vector<1xf32> %5 = vector.extract %2[1] : vector<1xf32> from vector<8x1xf32> %6 = arith.addf %5, %4 : vector<1xf32> ... (repeats for all 8 elements) ... %17 = vector.extract %2[7] : vector<1xf32> from vector<8x1xf32> %18 = arith.addf %17, %16 : vector<1xf32> %19 = vector.extract %18[0] : f32 from vector<1xf32> return %19 : f32 } ``` * when lowering using the inner-reduction strategy, the compiler would first unnecessarily transform it into a 2-D multi_reduction operation <1x8xf32> and then extract an <8xf32> vector and apply reduction. The canonicalization and folding would lead to the following final result: ```mlir func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 { %0 = vector.reduction , %arg0, %arg1 : vector<8xf32> into f32 return %0 : f32 } ``` Now, after this change: * when lowering the compiler now produces for both strategies in one step. ``` func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 { %0 = vector.reduction , %arg0, %arg1 : vector<8xf32> into f32 return %0 : f32 } ``` This pattern is also useful for an ongoing refactoring that is happening in the multi_reduction patterns. It is the only pattern that increases multi_reduction in rank and would lead to an infinite loop when attempting to reach a fixed point once we generalize other unrolling patterns. Assisted-by: Claude --- .../Vector/TransformOps/VectorTransformOps.td | 18 ++-- .../Vector/Transforms/LoweringPatterns.h | 11 +-- .../TransformOps/VectorTransformOps.cpp | 4 +- .../Transforms/LowerVectorMultiReduction.cpp | 93 ++++++------------- mlir/test/Dialect/LLVM/transform-e2e.mlir | 2 +- .../test/Dialect/Vector/transform-vector.mlir | 2 +- ...ir => vector-multi-reduction-reorder.mlir} | 50 +--------- .../vector-multi-reduction-unrolling.mlir | 33 +++++-- .../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 2 +- .../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 2 +- .../Linalg/CPU/test-matmul-masked-vec.mlir | 2 +- .../python/dialects/transform_vector_ext.py | 8 +- 12 files changed, 85 insertions(+), 142 deletions(-) rename mlir/test/Dialect/Vector/{vector-multi-reduction-reorder-and-expand.mlir => vector-multi-reduction-reorder.mlir} (51%) diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 9fec5804d0b3..dcd5f6ff3ad7 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -223,18 +223,17 @@ def ApplyMaterializeMasksPatternsOp : Op]> { let description = [{ Indicates that vector multi_reduction-like operations should be transformed such that all reduction dimensions become innermost or - outermost, and 1-D reductions are lifted to 2-D. + outermost, depending on `lowering_strategy`. This populates the patterns from - `populateVectorMultiReductionReorderAndExpandPatterns`, i.e.: + `populateVectorMultiReductionReorderPatterns`, i.e.: * `InnerOuterDimReductionConversion` - * `OneDimMultiReductionToTwoDim` }]; let arguments = (ins DefaultValuedAttr]> { let description = [{ - Indicates that 2-D vector multi_reduction operations should be unrolled - into either a sequence of vector.reduction ops (innerreduction) or - element-wise arith ops (innerparallel). + Indicates that vector multi_reduction operations should be unrolled. + 1-D multi_reductions are converted directly to vector.reduction. + 2-D multi_reductions are unrolled into either a sequence of + vector.reduction ops (innerreduction) or element-wise arith ops + (innerparallel). This populates the patterns from `populateVectorMultiReductionUnrollingPatterns`, i.e.: + * `OneDimMultiReductionToReduction` * `TwoDimMultiReductionToReduction` (innerreduction) * `TwoDimMultiReductionToElementWise` (innerparallel) }]; diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index a933f68732a4..aa75eff409ef 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -66,13 +66,7 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, /// Rewrites vector.multi_reduction such that all reduction dimensions are /// either innermost or outermost, by adding the proper vector.transpose /// operations. -/// -/// [OneDimMultiReductionToTwoDim] -/// For cases that reduce to 1-D vector reduction (and are thus missing -/// either a parallel or a reduction), we lift them back up to 2-D with a simple -/// vector.shape_cast to vector<1xk> so that the other patterns can kick in, -/// thus fully exiting out of the vector.multi_reduction abstraction. -void populateVectorMultiReductionReorderAndExpandPatterns( +void populateVectorMultiReductionReorderPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit = 1); @@ -89,6 +83,9 @@ void populateVectorMultiReductionFlatteningPatterns( /// Populate the pattern set with the following patterns: /// +/// [OneDimMultiReductionToReduction] +/// Converts 1-D vector.multi_reduction to vector.reduction. +/// /// [TwoDimMultiReductionToElementWise] /// Once in 2-D vector.multi_reduction form, with an **outermost** reduction /// dimension, unroll the outer dimension to obtain a sequence of 1-D vector diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 9da4be88586f..312bd28ad48c 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -129,11 +129,11 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( //===----------------------------------------------------------------------===// // Multi-reduction patterns //===----------------------------------------------------------------------===// -void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns( +void transform::ApplyReorderMultiReductionPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); - vector::populateVectorMultiReductionReorderAndExpandPatterns( + vector::populateVectorMultiReductionReorderPatterns( patterns, vectorTransformOptions.vectorMultiReductionLowering); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 0d9ff95e1279..76599822fbfe 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -375,7 +375,7 @@ struct TwoDimMultiReductionToElementWise } }; -/// Lowers 2D vector.multi_reduction to a squence of vector.reduction Ops +/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction Ops. /// /// The reduction dimension must be the inner-most dimension. /// @@ -443,75 +443,42 @@ struct TwoDimMultiReductionToReduction } }; -/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d -/// form with both a single parallel and reduction dimension. -/// This is achieved with a simple vector.shape_cast that inserts a leading 1. -/// The case with a single parallel dimension is a noop and folds away -/// separately. -struct OneDimMultiReductionToTwoDim - : public OpRewritePattern { - using Base::Base; +/// Converts 1D vector.multi_reduction directly to vector.reduction. +/// +/// Example: +/// ```mlir +/// // Before +/// %r = vector.multi_reduction , %v, %acc [0] : vector to f32 +/// +/// // After +/// %r = vector.reduction , %v, %acc : vector into f32 +/// ``` +struct OneDimMultiReductionToReduction + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; - LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, - PatternRewriter &rewriter) const override { + FailureOr + matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, + vector::MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - // Rank-1 or bail. if (srcRank != 1) return failure(); - // Vector mask setup. - OpBuilder::InsertionGuard guard(rewriter); - auto maskableOp = - cast(multiReductionOp.getOperation()); - Operation *rootOp; - Value mask; - if (maskableOp.isMasked()) { - rewriter.setInsertionPoint(maskableOp.getMaskingOp()); - rootOp = maskableOp.getMaskingOp(); - mask = maskableOp.getMaskingOp().getMask(); - } else { - rootOp = multiReductionOp; - } + if (!multiReductionOp.isReducedDim(0)) + return failure(); auto loc = multiReductionOp.getLoc(); - auto srcVectorType = multiReductionOp.getSourceVectorType(); - auto srcShape = srcVectorType.getShape(); - auto castedType = VectorType::get( - ArrayRef{1, srcShape.back()}, srcVectorType.getElementType(), - ArrayRef{false, srcVectorType.getScalableDims().back()}); + Value mask = maskingOp ? maskingOp.getMask() : Value(); - auto accType = - VectorType::get(ArrayRef{1}, srcVectorType.getElementType()); - assert(!llvm::isa(multiReductionOp.getDestType()) && - "multi_reduction with a single dimension expects a scalar result"); + Operation *reductionOp = vector::ReductionOp::create( + rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(), + multiReductionOp.getAcc()); - // If the unique dim is reduced and we insert a parallel in front, we need a - // {false, true} mask. - SmallVector reductionMask{false, true}; + if (mask) + reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask); - /// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0) - Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType, - multiReductionOp.getSource()); - Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType, - multiReductionOp.getAcc()); - Value castMask; - if (maskableOp.isMasked()) { - auto maskType = llvm::cast(mask.getType()); - auto castMaskType = VectorType::get( - ArrayRef{1, maskType.getShape().back()}, - maskType.getElementType(), - ArrayRef{false, maskType.getScalableDims().back()}); - castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask); - } - - Operation *newOp = vector::MultiDimReductionOp::create( - rewriter, loc, cast, castAcc, reductionMask, - multiReductionOp.getKind()); - newOp = vector::maskOperation(rewriter, newOp, castMask); - - rewriter.replaceOpWithNewOp(rootOp, newOp->getResult(0), - ArrayRef{0}); - return success(); + return reductionOp->getResult(0); } }; @@ -527,7 +494,7 @@ struct LowerVectorMultiReductionPass MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); - mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns( + mlir::vector::populateVectorMultiReductionReorderPatterns( patterns, this->loweringStrategy); if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); @@ -552,10 +519,9 @@ struct LowerVectorMultiReductionPass } // namespace -void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns( +void mlir::vector::populateVectorMultiReductionReorderPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); patterns.add(patterns.getContext(), options, benefit); } @@ -569,6 +535,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns( void mlir::vector::populateVectorMultiReductionUnrollingPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); if (options == VectorMultiReductionLowering ::InnerReduction) patterns.add(patterns.getContext(), benefit); diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir index ab58dda91a91..bf7eba6e5017 100644 --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -30,7 +30,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" transform.apply_patterns.vector.transfer_permutation_patterns - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel" + transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index a37105d57321..4dc11c26e83f 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -39,7 +39,7 @@ module attributes {transform.with_named_sequence} { } : !transform.any_op transform.apply_patterns to %f { - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel" + transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" } : !transform.any_op diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-reorder.mlir similarity index 51% rename from mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir rename to mlir/test/Dialect/Vector/vector-multi-reduction-reorder.mlir index 7f41f7e9e1dd..0a22205f61f9 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-reorder-and-expand.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-reorder.mlir @@ -36,50 +36,10 @@ func.func @transpose_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf3 return %0 : vector<4xf32> } -// ALL-LABEL: func @one_dim_to_two_dim -// ALL-SAME: %[[INPUT:.+]]: vector<8xf32> -// ALL-SAME: %[[ACC:.+]]: f32 -func.func @one_dim_to_two_dim(%arg0: vector<8xf32>, %acc: f32) -> f32 { - // ALL: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<8xf32> to vector<1x8xf32> - // ALL: %[[BROADCAST:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32> - // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction , %[[CAST]], %[[BROADCAST]] [1] - // INNER_REDUCTION: %[[SCALAR:.+]] = vector.extract %[[RESULT]][0] - // INNER_PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[CAST]], [1, 0] - // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction , %[[TRANSPOSED]], %[[BROADCAST]] [0] - // INNER_PARALLEL: %[[SCALAR:.+]] = vector.extract %[[RESULT]][0] +// ALL-LABEL: func @negative_one_dim +func.func @negative_one_dim(%arg0: vector<8xf32>, %acc: f32) -> f32 { + // ALL: vector.multi_reduction , {{.+}} [0] : vector<8xf32> to f32 %0 = vector.multi_reduction , %arg0, %acc [0] : vector<8xf32> to f32 - // ALL: return %[[SCALAR]] - return %0 : f32 -} - -// INNER_REDUCTION-LABEL: func @one_dim_to_two_dim_scalable -// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<[4]xf32> -// INNER_REDUCTION-SAME: %[[ACC:.+]]: f32 -func.func @one_dim_to_two_dim_scalable(%arg0: vector<[4]xf32>, %acc: f32) -> f32 { - // INNER_REDUCTION: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<[4]xf32> to vector<1x[4]xf32> - // INNER_REDUCTION: %[[BROADCAST:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32> - // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction , %[[CAST]], %[[BROADCAST]] [1] - %0 = vector.multi_reduction , %arg0, %acc [0] : vector<[4]xf32> to f32 - // INNER_REDUCTION: %[[EXTRACT:.+]] = vector.extract %[[RESULT]][0] - // INNER_REDUCTION: return %[[EXTRACT]] - return %0 : f32 -} - -// INNER_REDUCTION-LABEL: func @one_dim_to_two_dim_masked -// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<8xf32> -// INNER_REDUCTION-SAME: %[[ACC:.+]]: f32 -// INNER_REDUCTION-SAME: %[[MASK:.+]]: vector<8xi1> -func.func @one_dim_to_two_dim_masked(%arg0: vector<8xf32>, %acc: f32, %mask: vector<8xi1>) -> f32 { - // INNER_REDUCTION: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<8xf32> to vector<1x8xf32> - // INNER_REDUCTION: %[[BROADCAST_ACC:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32> - // INNER_REDUCTION: %[[BROADCAST_MASK:.+]] = vector.broadcast %[[MASK]] : vector<8xi1> to vector<1x8xi1> - // INNER_REDUCTION: %[[RESULT:.+]] = vector.mask %[[BROADCAST_MASK]] { - // INNER_REDUCTION: vector.multi_reduction , %[[CAST]], %[[BROADCAST_ACC]] [1] - %0 = vector.mask %mask { - vector.multi_reduction , %arg0, %acc [0] : vector<8xf32> to f32 - } : vector<8xi1> -> f32 - // INNER_REDUCTION: %[[EXTRACT:.+]] = vector.extract %[[RESULT]][0] - // INNER_REDUCTION: return %[[EXTRACT]] return %0 : f32 } @@ -87,7 +47,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction" + transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield } @@ -95,7 +55,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel" + transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" } : !transform.op<"func.func"> transform.yield } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir index bc0d192e012e..447416ccba63 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir @@ -1,15 +1,32 @@ // RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=INNER_REDUCTION,ALL // RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=INNER_PARALLEL,ALL -// ALL-LABEL: func @negative_rank1_and_rank3 -func.func @negative_rank1_and_rank3( - %rank1: vector<8xf32>, %rank1_acc: f32, - %rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> (f32, vector<2x3xf32>) { - // ALL: vector.multi_reduction , {{.+}} [0] : vector<8xf32> to f32 - %0 = vector.multi_reduction , %rank1, %rank1_acc [0] : vector<8xf32> to f32 +// ALL-LABEL: func @one_dim_reduction +// ALL-SAME: %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32 +func.func @one_dim_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 { + // ALL: %[[RESULT:.+]] = vector.reduction , %[[INPUT]], %[[ACC]] : vector<8xf32> into f32 + %0 = vector.multi_reduction , %arg0, %acc [0] : vector<8xf32> to f32 + // ALL: return %[[RESULT]] + return %0 : f32 +} + +// ALL-LABEL: func @one_dim_reduction_masked +// ALL-SAME: %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32, %[[MASK:.+]]: vector<8xi1> +func.func @one_dim_reduction_masked(%arg0: vector<8xf32>, %acc: f32, %mask: vector<8xi1>) -> f32 { + // ALL: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction , %[[INPUT]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32 + %0 = vector.mask %mask { + vector.multi_reduction , %arg0, %acc [0] : vector<8xf32> to f32 + } : vector<8xi1> -> f32 + // ALL: return %[[RESULT]] + return %0 : f32 +} + +// ALL-LABEL: func @negative_rank3 +func.func @negative_rank3( + %rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> vector<2x3xf32> { // ALL: vector.multi_reduction , {{.+}} [2] : vector<2x3x4xf32> to vector<2x3xf32> - %1 = vector.multi_reduction , %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32> - return %0, %1 : f32, vector<2x3xf32> + %0 = vector.multi_reduction , %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32> + return %0 : vector<2x3xf32> } // ALL-LABEL: func @inner_reduction_2d diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir index 25b65080339d..a7b0b27ca5fb 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir @@ -150,7 +150,7 @@ module attributes {transform.with_named_sequence} { // Step 3: Lower vector.multi_reduction transform.apply_patterns to %func { transform.apply_patterns.vector.lower_masked_transfers - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction" + transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" } : !transform.op<"func.func"> diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir index 6072b44adf4f..4adc68966f17 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir @@ -155,7 +155,7 @@ module attributes {transform.with_named_sequence} { // Step 3: Lower vector.multi_reduction transform.apply_patterns to %func { transform.apply_patterns.vector.lower_masked_transfers - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction" + transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" } : !transform.op<"func.func"> diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir index 3c4f10316d0f..0883e7b698f5 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir @@ -53,7 +53,7 @@ module attributes {transform.with_named_sequence} { %func_op = transform.get_parent_op %0 : (!transform.any_op) -> !transform.op<"func.func"> transform.structured.vectorize %0 vector_sizes [4, 4, 2] : !transform.any_op transform.apply_patterns to %func_op { - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction" + transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" } : !transform.op<"func.func"> diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index 8a3091d0b1b0..a3c53a45048b 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -87,11 +87,11 @@ def enum_configurable_patterns(): lowering_strategy=vector.VectorContractLowering.ParallelArith ) - # CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims - vector.ApplyReorderAndExpandMultiReductionPatternsOp() - # CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims + # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims + vector.ApplyReorderMultiReductionPatternsOp() + # CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims # CHECK-SAME: lowering_strategy = innerreduction - vector.ApplyReorderAndExpandMultiReductionPatternsOp( + vector.ApplyReorderMultiReductionPatternsOp( lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction )