[MLIR][Vector] Enhance vector.multi_reduction unrolling to handle scalar result (#188633)
Previously, UnrollMultiReductionPattern bailed out when every dimension was reduced, i.e. when the result was a scalar. This PR adds support for this case by tiling the source vector and chaining the partial reductions through the accumulator operand.
This commit is contained in:
parent
1a1fbf967a
commit
b3ca423a78
@ -381,21 +381,40 @@ struct UnrollMultiReductionPattern
|
||||
|
||||
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto resultType = reductionOp->getResult(0).getType();
|
||||
if (resultType.isIntOrFloat()) {
|
||||
return rewriter.notifyMatchFailure(reductionOp,
|
||||
"Unrolling scalars is not supported");
|
||||
}
|
||||
std::optional<SmallVector<int64_t>> targetShape =
|
||||
getTargetShape(options, reductionOp);
|
||||
if (!targetShape)
|
||||
return failure();
|
||||
SmallVector<int64_t> originalSize = *reductionOp.getShapeForUnroll();
|
||||
Location loc = reductionOp.getLoc();
|
||||
auto resultType = reductionOp->getResult(0).getType();
|
||||
|
||||
// Handle scalar result case: all dimensions are reduced.
|
||||
// Each source tile is reduced to a scalar, and partial results are
|
||||
// chained through the accumulator operand.
|
||||
if (resultType.isIntOrFloat()) {
|
||||
Value accumulator = reductionOp.getAcc();
|
||||
for (SmallVector<int64_t> offsets :
|
||||
StaticTileOffsetRange(originalSize, *targetShape)) {
|
||||
SmallVector<int64_t> operandStrides(offsets.size(), 1);
|
||||
Value slicedOperand =
|
||||
rewriter.createOrFold<vector::ExtractStridedSliceOp>(
|
||||
loc, reductionOp.getSource(), offsets, *targetShape,
|
||||
operandStrides);
|
||||
Operation *newOp = cloneOpWithOperandsAndTypes(
|
||||
rewriter, loc, reductionOp, {slicedOperand, accumulator},
|
||||
resultType);
|
||||
accumulator = newOp->getResult(0);
|
||||
}
|
||||
rewriter.replaceOp(reductionOp, accumulator);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Vector result case.
|
||||
llvm::MapVector<
|
||||
SmallVector<int64_t>, Value,
|
||||
llvm::DenseMap<SmallVector<int64_t>, unsigned, OffsetMapInfo>>
|
||||
accCache;
|
||||
Location loc = reductionOp.getLoc();
|
||||
|
||||
// Stride of the ratios; this gives us the offsets of sliceCount in a basis
|
||||
// of multiples of the targetShape.
|
||||
|
||||
@ -245,15 +245,17 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) ->
|
||||
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
|
||||
// CHECK: return %[[V2]] : vector<4xf32>
|
||||
|
||||
// This is a negative test case to ensure that further unrolling is not performed. Since the vector.multi_reduction
|
||||
// operation has already been unrolled, attempting additional unrolling should not be allowed.
|
||||
func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 {
|
||||
func.func @vector_multi_reduction_scalar(%v: vector<4x2xf32>, %acc: f32) -> f32 {
|
||||
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32
|
||||
return %0 : f32
|
||||
}
|
||||
// CHECK-LABEL: func @negative_vector_multi_reduction
|
||||
// CHECK-NEXT: %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
|
||||
// CHECK-NEXT: return %[[R0]] : f32
|
||||
// CHECK-LABEL: func @vector_multi_reduction_scalar
|
||||
// CHECK-SAME: %[[V:.*]]: vector<4x2xf32>, %[[ACC:.*]]: f32
|
||||
// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
|
||||
// CHECK: %[[R0:.*]] = vector.multi_reduction <add>, %[[S0]], %[[ACC]] [0, 1] : vector<2x2xf32> to f32
|
||||
// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
|
||||
// CHECK: %[[R1:.*]] = vector.multi_reduction <add>, %[[S1]], %[[R0]] [0, 1] : vector<2x2xf32> to f32
|
||||
// CHECK: return %[[R1]] : f32
|
||||
|
||||
func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
|
||||
%0 = vector.reduction <add>, %v : vector<8xf32> into f32
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user