[mlir][vector] Replace OneDimMultiReductionToTwoDim with OneDimMultiReductionToReduction (#184241)
The `OneDimMultiReductionToTwoDim` pattern had some issues. For the
input program:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
%0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
return %0 : f32
}
```
* when lowering using the inner-parallel strategy, the compiler would
essentially produce scalar code:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
%0 = vector.shape_cast %arg0 : vector<8xf32> to vector<1x8xf32>
%1 = vector.broadcast %arg1 : f32 to vector<1xf32>
%2 = vector.transpose %0, [1, 0] : vector<1x8xf32> to vector<8x1xf32>
%3 = vector.extract %2[0] : vector<1xf32> from vector<8x1xf32>
%4 = arith.addf %3, %1 : vector<1xf32>
%5 = vector.extract %2[1] : vector<1xf32> from vector<8x1xf32>
%6 = arith.addf %5, %4 : vector<1xf32>
... (repeats for all 8 elements) ...
%17 = vector.extract %2[7] : vector<1xf32> from vector<8x1xf32>
%18 = arith.addf %17, %16 : vector<1xf32>
%19 = vector.extract %18[0] : f32 from vector<1xf32>
return %19 : f32
}
```
* when lowering using the inner-reduction strategy, the compiler would
first unnecessarily transform it into a 2-D multi_reduction operation
<1x8xf32> and then extract an <8xf32> vector and apply reduction. The
canonicalization and folding would lead to the following final result:
```mlir
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<8xf32> into f32
return %0 : f32
}
```
Now, after this change:
* when lowering the compiler now produces for both strategies in one
step.
```
func.func @rank1_multi_reduction(%arg0: vector<8xf32>, %arg1: f32) -> f32 {
%0 = vector.reduction <add>, %arg0, %arg1 : vector<8xf32> into f32
return %0 : f32
}
```
This pattern is also useful for an ongoing refactoring that is happening
in the multi_reduction patterns. It is the only pattern that increases
multi_reduction in rank and would lead to an infinite loop when
attempting to reach a fixed point once we generalize other unrolling
patterns.
Assisted-by: Claude
This commit is contained in:
parent
7b72b5fde4
commit
613a5c555e
@ -223,18 +223,17 @@ def ApplyMaterializeMasksPatternsOp : Op<Transform_Dialect,
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
|
||||
"apply_patterns.vector.reorder_and_expand_multi_reduction_dims",
|
||||
def ApplyReorderMultiReductionPatternsOp: Op<Transform_Dialect,
|
||||
"apply_patterns.vector.reorder_multi_reduction_dims",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
let description = [{
|
||||
Indicates that vector multi_reduction-like operations should be
|
||||
transformed such that all reduction dimensions become innermost or
|
||||
outermost, and 1-D reductions are lifted to 2-D.
|
||||
outermost, depending on `lowering_strategy`.
|
||||
|
||||
This populates the patterns from
|
||||
`populateVectorMultiReductionReorderAndExpandPatterns`, i.e.:
|
||||
`populateVectorMultiReductionReorderPatterns`, i.e.:
|
||||
* `InnerOuterDimReductionConversion`
|
||||
* `OneDimMultiReductionToTwoDim`
|
||||
}];
|
||||
|
||||
let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
|
||||
@ -267,12 +266,15 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
|
||||
"apply_patterns.vector.multi_reduction_unrolling",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
let description = [{
|
||||
Indicates that 2-D vector multi_reduction operations should be unrolled
|
||||
into either a sequence of vector.reduction ops (innerreduction) or
|
||||
element-wise arith ops (innerparallel).
|
||||
Indicates that vector multi_reduction operations should be unrolled.
|
||||
1-D multi_reductions are converted directly to vector.reduction.
|
||||
2-D multi_reductions are unrolled into either a sequence of
|
||||
vector.reduction ops (innerreduction) or element-wise arith ops
|
||||
(innerparallel).
|
||||
|
||||
This populates the patterns from
|
||||
`populateVectorMultiReductionUnrollingPatterns`, i.e.:
|
||||
* `OneDimMultiReductionToReduction`
|
||||
* `TwoDimMultiReductionToReduction` (innerreduction)
|
||||
* `TwoDimMultiReductionToElementWise` (innerparallel)
|
||||
}];
|
||||
|
||||
@ -66,13 +66,7 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
|
||||
/// Rewrites vector.multi_reduction such that all reduction dimensions are
|
||||
/// either innermost or outermost, by adding the proper vector.transpose
|
||||
/// operations.
|
||||
///
|
||||
/// [OneDimMultiReductionToTwoDim]
|
||||
/// For cases that reduce to 1-D vector<k> reduction (and are thus missing
|
||||
/// either a parallel or a reduction), we lift them back up to 2-D with a simple
|
||||
/// vector.shape_cast to vector<1xk> so that the other patterns can kick in,
|
||||
/// thus fully exiting out of the vector.multi_reduction abstraction.
|
||||
void populateVectorMultiReductionReorderAndExpandPatterns(
|
||||
void populateVectorMultiReductionReorderPatterns(
|
||||
RewritePatternSet &patterns, VectorMultiReductionLowering options,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
@ -89,6 +83,9 @@ void populateVectorMultiReductionFlatteningPatterns(
|
||||
|
||||
/// Populate the pattern set with the following patterns:
|
||||
///
|
||||
/// [OneDimMultiReductionToReduction]
|
||||
/// Converts 1-D vector.multi_reduction to vector.reduction.
|
||||
///
|
||||
/// [TwoDimMultiReductionToElementWise]
|
||||
/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction
|
||||
/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector
|
||||
|
||||
@ -129,11 +129,11 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Multi-reduction patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
|
||||
void transform::ApplyReorderMultiReductionPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
vector::VectorTransformsOptions vectorTransformOptions;
|
||||
vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
|
||||
vector::populateVectorMultiReductionReorderAndExpandPatterns(
|
||||
vector::populateVectorMultiReductionReorderPatterns(
|
||||
patterns, vectorTransformOptions.vectorMultiReductionLowering);
|
||||
}
|
||||
|
||||
|
||||
@ -375,7 +375,7 @@ struct TwoDimMultiReductionToElementWise
|
||||
}
|
||||
};
|
||||
|
||||
/// Lowers 2D vector.multi_reduction to a squence of vector.reduction Ops
|
||||
/// Lowers 2D vector.multi_reduction to a sequence of vector.reduction Ops.
|
||||
///
|
||||
/// The reduction dimension must be the inner-most dimension.
|
||||
///
|
||||
@ -443,75 +443,42 @@ struct TwoDimMultiReductionToReduction
|
||||
}
|
||||
};
|
||||
|
||||
/// Converts 1d vector.multi_reduction with a single reduction dimension to a 2d
|
||||
/// form with both a single parallel and reduction dimension.
|
||||
/// This is achieved with a simple vector.shape_cast that inserts a leading 1.
|
||||
/// The case with a single parallel dimension is a noop and folds away
|
||||
/// separately.
|
||||
struct OneDimMultiReductionToTwoDim
|
||||
: public OpRewritePattern<vector::MultiDimReductionOp> {
|
||||
using Base::Base;
|
||||
/// Converts 1D vector.multi_reduction directly to vector.reduction.
|
||||
///
|
||||
/// Example:
|
||||
/// ```mlir
|
||||
/// // Before
|
||||
/// %r = vector.multi_reduction <add>, %v, %acc [0] : vector<Nxf32> to f32
|
||||
///
|
||||
/// // After
|
||||
/// %r = vector.reduction <add>, %v, %acc : vector<Nxf32> into f32
|
||||
/// ```
|
||||
struct OneDimMultiReductionToReduction
|
||||
: public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
|
||||
using MaskableOpRewritePattern::MaskableOpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
FailureOr<Value>
|
||||
matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
|
||||
vector::MaskingOpInterface maskingOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
|
||||
// Rank-1 or bail.
|
||||
if (srcRank != 1)
|
||||
return failure();
|
||||
|
||||
// Vector mask setup.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
auto maskableOp =
|
||||
cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
|
||||
Operation *rootOp;
|
||||
Value mask;
|
||||
if (maskableOp.isMasked()) {
|
||||
rewriter.setInsertionPoint(maskableOp.getMaskingOp());
|
||||
rootOp = maskableOp.getMaskingOp();
|
||||
mask = maskableOp.getMaskingOp().getMask();
|
||||
} else {
|
||||
rootOp = multiReductionOp;
|
||||
}
|
||||
if (!multiReductionOp.isReducedDim(0))
|
||||
return failure();
|
||||
|
||||
auto loc = multiReductionOp.getLoc();
|
||||
auto srcVectorType = multiReductionOp.getSourceVectorType();
|
||||
auto srcShape = srcVectorType.getShape();
|
||||
auto castedType = VectorType::get(
|
||||
ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
|
||||
ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
|
||||
Value mask = maskingOp ? maskingOp.getMask() : Value();
|
||||
|
||||
auto accType =
|
||||
VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
|
||||
assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
|
||||
"multi_reduction with a single dimension expects a scalar result");
|
||||
Operation *reductionOp = vector::ReductionOp::create(
|
||||
rewriter, loc, multiReductionOp.getKind(), multiReductionOp.getSource(),
|
||||
multiReductionOp.getAcc());
|
||||
|
||||
// If the unique dim is reduced and we insert a parallel in front, we need a
|
||||
// {false, true} mask.
|
||||
SmallVector<bool, 2> reductionMask{false, true};
|
||||
if (mask)
|
||||
reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
|
||||
|
||||
/// vector.extract(vector.multi_reduce(vector.shape_cast(v, 1xk)), 0)
|
||||
Value cast = vector::ShapeCastOp::create(rewriter, loc, castedType,
|
||||
multiReductionOp.getSource());
|
||||
Value castAcc = vector::BroadcastOp::create(rewriter, loc, accType,
|
||||
multiReductionOp.getAcc());
|
||||
Value castMask;
|
||||
if (maskableOp.isMasked()) {
|
||||
auto maskType = llvm::cast<VectorType>(mask.getType());
|
||||
auto castMaskType = VectorType::get(
|
||||
ArrayRef<int64_t>{1, maskType.getShape().back()},
|
||||
maskType.getElementType(),
|
||||
ArrayRef<bool>{false, maskType.getScalableDims().back()});
|
||||
castMask = vector::BroadcastOp::create(rewriter, loc, castMaskType, mask);
|
||||
}
|
||||
|
||||
Operation *newOp = vector::MultiDimReductionOp::create(
|
||||
rewriter, loc, cast, castAcc, reductionMask,
|
||||
multiReductionOp.getKind());
|
||||
newOp = vector::maskOperation(rewriter, newOp, castMask);
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ExtractOp>(rootOp, newOp->getResult(0),
|
||||
ArrayRef<int64_t>{0});
|
||||
return success();
|
||||
return reductionOp->getResult(0);
|
||||
}
|
||||
};
|
||||
|
||||
@ -527,7 +494,7 @@ struct LowerVectorMultiReductionPass
|
||||
MLIRContext *context = op->getContext();
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
|
||||
mlir::vector::populateVectorMultiReductionReorderPatterns(
|
||||
patterns, this->loweringStrategy);
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
@ -552,10 +519,9 @@ struct LowerVectorMultiReductionPass
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::vector::populateVectorMultiReductionReorderAndExpandPatterns(
|
||||
void mlir::vector::populateVectorMultiReductionReorderPatterns(
|
||||
RewritePatternSet &patterns, VectorMultiReductionLowering options,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<OneDimMultiReductionToTwoDim>(patterns.getContext(), benefit);
|
||||
patterns.add<InnerOuterDimReductionConversion>(patterns.getContext(), options,
|
||||
benefit);
|
||||
}
|
||||
@ -569,6 +535,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
|
||||
void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
|
||||
RewritePatternSet &patterns, VectorMultiReductionLowering options,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
|
||||
if (options == VectorMultiReductionLowering ::InnerReduction)
|
||||
patterns.add<TwoDimMultiReductionToReduction>(patterns.getContext(),
|
||||
benefit);
|
||||
|
||||
@ -30,7 +30,7 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.apply_patterns to %f {
|
||||
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
|
||||
transform.apply_patterns.vector.transfer_permutation_patterns
|
||||
transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
|
||||
|
||||
@ -39,7 +39,7 @@ module attributes {transform.with_named_sequence} {
|
||||
} : !transform.any_op
|
||||
|
||||
transform.apply_patterns to %f {
|
||||
transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
|
||||
} : !transform.any_op
|
||||
|
||||
@ -36,50 +36,10 @@ func.func @transpose_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf3
|
||||
return %0 : vector<4xf32>
|
||||
}
|
||||
|
||||
// ALL-LABEL: func @one_dim_to_two_dim
|
||||
// ALL-SAME: %[[INPUT:.+]]: vector<8xf32>
|
||||
// ALL-SAME: %[[ACC:.+]]: f32
|
||||
func.func @one_dim_to_two_dim(%arg0: vector<8xf32>, %acc: f32) -> f32 {
|
||||
// ALL: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<8xf32> to vector<1x8xf32>
|
||||
// ALL: %[[BROADCAST:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32>
|
||||
// INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CAST]], %[[BROADCAST]] [1]
|
||||
// INNER_REDUCTION: %[[SCALAR:.+]] = vector.extract %[[RESULT]][0]
|
||||
// INNER_PARALLEL: %[[TRANSPOSED:.+]] = vector.transpose %[[CAST]], [1, 0]
|
||||
// INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[TRANSPOSED]], %[[BROADCAST]] [0]
|
||||
// INNER_PARALLEL: %[[SCALAR:.+]] = vector.extract %[[RESULT]][0]
|
||||
// ALL-LABEL: func @negative_one_dim
|
||||
func.func @negative_one_dim(%arg0: vector<8xf32>, %acc: f32) -> f32 {
|
||||
// ALL: vector.multi_reduction <add>, {{.+}} [0] : vector<8xf32> to f32
|
||||
%0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
|
||||
// ALL: return %[[SCALAR]]
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// INNER_REDUCTION-LABEL: func @one_dim_to_two_dim_scalable
|
||||
// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<[4]xf32>
|
||||
// INNER_REDUCTION-SAME: %[[ACC:.+]]: f32
|
||||
func.func @one_dim_to_two_dim_scalable(%arg0: vector<[4]xf32>, %acc: f32) -> f32 {
|
||||
// INNER_REDUCTION: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<[4]xf32> to vector<1x[4]xf32>
|
||||
// INNER_REDUCTION: %[[BROADCAST:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32>
|
||||
// INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CAST]], %[[BROADCAST]] [1]
|
||||
%0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<[4]xf32> to f32
|
||||
// INNER_REDUCTION: %[[EXTRACT:.+]] = vector.extract %[[RESULT]][0]
|
||||
// INNER_REDUCTION: return %[[EXTRACT]]
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// INNER_REDUCTION-LABEL: func @one_dim_to_two_dim_masked
|
||||
// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<8xf32>
|
||||
// INNER_REDUCTION-SAME: %[[ACC:.+]]: f32
|
||||
// INNER_REDUCTION-SAME: %[[MASK:.+]]: vector<8xi1>
|
||||
func.func @one_dim_to_two_dim_masked(%arg0: vector<8xf32>, %acc: f32, %mask: vector<8xi1>) -> f32 {
|
||||
// INNER_REDUCTION: %[[CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<8xf32> to vector<1x8xf32>
|
||||
// INNER_REDUCTION: %[[BROADCAST_ACC:.+]] = vector.broadcast %[[ACC]] : f32 to vector<1xf32>
|
||||
// INNER_REDUCTION: %[[BROADCAST_MASK:.+]] = vector.broadcast %[[MASK]] : vector<8xi1> to vector<1x8xi1>
|
||||
// INNER_REDUCTION: %[[RESULT:.+]] = vector.mask %[[BROADCAST_MASK]] {
|
||||
// INNER_REDUCTION: vector.multi_reduction <add>, %[[CAST]], %[[BROADCAST_ACC]] [1]
|
||||
%0 = vector.mask %mask {
|
||||
vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
|
||||
} : vector<8xi1> -> f32
|
||||
// INNER_REDUCTION: %[[EXTRACT:.+]] = vector.extract %[[RESULT]][0]
|
||||
// INNER_REDUCTION: return %[[EXTRACT]]
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
@ -87,7 +47,7 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
|
||||
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
|
||||
transform.apply_patterns to %func_op {
|
||||
transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
} : !transform.op<"func.func">
|
||||
transform.yield
|
||||
}
|
||||
@ -95,7 +55,7 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
|
||||
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
|
||||
transform.apply_patterns to %func_op {
|
||||
transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel"
|
||||
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
|
||||
} : !transform.op<"func.func">
|
||||
transform.yield
|
||||
}
|
||||
@ -1,15 +1,32 @@
|
||||
// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefixes=INNER_REDUCTION,ALL
|
||||
// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefixes=INNER_PARALLEL,ALL
|
||||
|
||||
// ALL-LABEL: func @negative_rank1_and_rank3
|
||||
func.func @negative_rank1_and_rank3(
|
||||
%rank1: vector<8xf32>, %rank1_acc: f32,
|
||||
%rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> (f32, vector<2x3xf32>) {
|
||||
// ALL: vector.multi_reduction <add>, {{.+}} [0] : vector<8xf32> to f32
|
||||
%0 = vector.multi_reduction <add>, %rank1, %rank1_acc [0] : vector<8xf32> to f32
|
||||
// ALL-LABEL: func @one_dim_reduction
|
||||
// ALL-SAME: %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32
|
||||
func.func @one_dim_reduction(%arg0: vector<8xf32>, %acc: f32) -> f32 {
|
||||
// ALL: %[[RESULT:.+]] = vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32
|
||||
%0 = vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
|
||||
// ALL: return %[[RESULT]]
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// ALL-LABEL: func @one_dim_reduction_masked
|
||||
// ALL-SAME: %[[INPUT:.+]]: vector<8xf32>, %[[ACC:.+]]: f32, %[[MASK:.+]]: vector<8xi1>
|
||||
func.func @one_dim_reduction_masked(%arg0: vector<8xf32>, %acc: f32, %mask: vector<8xi1>) -> f32 {
|
||||
// ALL: %[[RESULT:.+]] = vector.mask %[[MASK]] { vector.reduction <add>, %[[INPUT]], %[[ACC]] : vector<8xf32> into f32 } : vector<8xi1> -> f32
|
||||
%0 = vector.mask %mask {
|
||||
vector.multi_reduction <add>, %arg0, %acc [0] : vector<8xf32> to f32
|
||||
} : vector<8xi1> -> f32
|
||||
// ALL: return %[[RESULT]]
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// ALL-LABEL: func @negative_rank3
|
||||
func.func @negative_rank3(
|
||||
%rank3: vector<2x3x4xf32>, %rank3_acc: vector<2x3xf32>) -> vector<2x3xf32> {
|
||||
// ALL: vector.multi_reduction <add>, {{.+}} [2] : vector<2x3x4xf32> to vector<2x3xf32>
|
||||
%1 = vector.multi_reduction <add>, %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32>
|
||||
return %0, %1 : f32, vector<2x3xf32>
|
||||
%0 = vector.multi_reduction <add>, %rank3, %rank3_acc [2] : vector<2x3x4xf32> to vector<2x3xf32>
|
||||
return %0 : vector<2x3xf32>
|
||||
}
|
||||
|
||||
// ALL-LABEL: func @inner_reduction_2d
|
||||
|
||||
@ -150,7 +150,7 @@ module attributes {transform.with_named_sequence} {
|
||||
// Step 3: Lower vector.multi_reduction
|
||||
transform.apply_patterns to %func {
|
||||
transform.apply_patterns.vector.lower_masked_transfers
|
||||
transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
|
||||
} : !transform.op<"func.func">
|
||||
|
||||
@ -155,7 +155,7 @@ module attributes {transform.with_named_sequence} {
|
||||
// Step 3: Lower vector.multi_reduction
|
||||
transform.apply_patterns to %func {
|
||||
transform.apply_patterns.vector.lower_masked_transfers
|
||||
transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
|
||||
} : !transform.op<"func.func">
|
||||
|
||||
@ -53,7 +53,7 @@ module attributes {transform.with_named_sequence} {
|
||||
%func_op = transform.get_parent_op %0 : (!transform.any_op) -> !transform.op<"func.func">
|
||||
transform.structured.vectorize %0 vector_sizes [4, 4, 2] : !transform.any_op
|
||||
transform.apply_patterns to %func_op {
|
||||
transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
|
||||
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
|
||||
} : !transform.op<"func.func">
|
||||
|
||||
@ -87,11 +87,11 @@ def enum_configurable_patterns():
|
||||
lowering_strategy=vector.VectorContractLowering.ParallelArith
|
||||
)
|
||||
|
||||
# CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims
|
||||
vector.ApplyReorderAndExpandMultiReductionPatternsOp()
|
||||
# CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims
|
||||
# CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
|
||||
vector.ApplyReorderMultiReductionPatternsOp()
|
||||
# CHECK: transform.apply_patterns.vector.reorder_multi_reduction_dims
|
||||
# CHECK-SAME: lowering_strategy = innerreduction
|
||||
vector.ApplyReorderAndExpandMultiReductionPatternsOp(
|
||||
vector.ApplyReorderMultiReductionPatternsOp(
|
||||
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user