[MLIR][XeGPU] Support leading unit dims in vector.multi_reduction in sg to wi pass (#188767)

This PR adds support for transforming vector.multi_reduction with
vectors > rank 2d with leading unit dims
This commit is contained in:
Nishant Patel 2026-03-30 09:29:20 -07:00 committed by GitHub
parent acb3d81a93
commit ad4d4c0f63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 96 additions and 18 deletions

View File

@ -588,6 +588,17 @@ struct SgToWiMultiDimReduction
assert(reductionDims.size() == 1 &&
"Expecting single reduction dimension for subgroup multi "
"reduction op");
// For rank > 2, ensure leading dimensions are unit.
VectorType sourceType = op.getSourceVectorType();
int64_t rank = sourceType.getRank();
if (rank > 2) {
ArrayRef<int64_t> shape = sourceType.getShape();
if (llvm::any_of(shape.take_front(rank - 2),
[](int64_t d) { return d != 1; }))
return rewriter.notifyMatchFailure(
op, "only unit leading dimensions are supported for "
"multi_reduction with rank > 2");
}
if (isReductionLaneLocal(op)) {
auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
VectorType resVecTy = dyn_cast<VectorType>(op.getType());

View File

@ -750,11 +750,18 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
TypedValue<VectorType> src, TypedValue<VectorType> acc,
vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize,
Location loc, PatternRewriter &rewriter) {
// Expecting a 2D source vector.
assert(src.getType().getRank() == 2 && "expected a 2D source vector");
VectorType sourceType = src.getType();
int64_t sourceH = sourceType.getShape()[0];
int64_t sourceW = sourceType.getShape()[1];
int64_t sourceRank = sourceType.getRank();
// Expecting at least a 2D source vector. Leading dimensions (all except the
// last two) must be unit.
assert(sourceRank >= 2 && "expected at least a 2D source vector");
for (int64_t i = 0; i < sourceRank - 2; ++i)
assert(sourceType.getShape()[i] == 1 &&
"expected leading dimensions to be unit");
int64_t rowIdx = sourceRank - 2;
int64_t columnIdx = sourceRank - 1;
int64_t sourceH = sourceType.getShape()[rowIdx];
int64_t sourceW = sourceType.getShape()[columnIdx];
// Create a constant vector to hold the result of the reduction.
TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
@ -763,39 +770,46 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
DenseElementsAttr::get(acc.getType(), zeroAttr));
// nSlices is the number of reduction operations needed to reduce the entire
// source vector. For example, if reductionDim is 0, we are reducing across
// rows, and each slice is a column of the source vector. So the number of
// slices is the number of columns, which is sourceW.
int nSlices = (reductionDim == 0) ? sourceW : sourceH;
// source vector. For example, if reductionDim is the row dim, we are
// reducing across rows, and each slice is a column. So the number of slices
// is the number of columns, which is sourceW.
int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
// For each slice of the source, extract the slice vector, do a reduction
// and, insert the reduced value back to the result vector.
int64_t accRank = acc.getType().getRank();
for (int i = 0; i < nSlices; ++i) {
SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
if (reductionDim == 1) {
sliceOffsets = {i, 0};
sliceSizes = {1, sourceW};
// Build nD offsets, sizes, and strides. Leading unit dims get
// offset=0, size=1. The last two dims are set based on reductionDim.
SmallVector<int64_t> sliceOffsets(sourceRank, 0);
SmallVector<int64_t> sliceSizes(sourceRank, 1);
SmallVector<int64_t> strides(sourceRank, 1);
if (reductionDim == columnIdx) {
sliceOffsets[rowIdx] = i;
sliceSizes[columnIdx] = sourceW;
} else {
sliceOffsets = {0, i};
sliceSizes = {sourceH, 1};
sliceOffsets[columnIdx] = i;
sliceSizes[rowIdx] = sourceH;
}
vector::ExtractStridedSliceOp extractOp =
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
sliceSizes, {1, 1});
sliceSizes, strides);
int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
vector::ShapeCastOp slice = vector::ShapeCastOp::create(
rewriter, loc,
VectorType::get({nSliceElements}, sourceType.getElementType()),
extractOp.getResult());
Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
SmallVector<int64_t> accIdx(accRank, 0);
accIdx[accRank - 1] = i;
Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
Value fullReduce =
xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize);
fullReduce =
vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract);
reductionResult =
vector::InsertOp::create(rewriter, loc, fullReduce, reductionResult, i);
reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
reductionResult, accIdx);
}
return reductionResult;
}

View File

@ -461,6 +461,59 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
gpu.return
}
// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x16x2xf32> to vector<1x2xf32>
// CHECK: gpu.return
gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
%src = arith.constant
{layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
dense<0.0> : vector<1x16x32xf32>
%acc = arith.constant
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>}
dense<0.0> : vector<1x32xf32>
%1 = vector.multi_reduction <add>, %src, %acc
{
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
}
[1] : vector<1x16x32xf32> to vector<1x32xf32>
gpu.return
}
// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x2xf32>
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
// CHECK: vector.extract_strided_slice %[[SRC]]
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [1, 1, 1], strides = [1, 1, 1]}
// CHECK: %[[ACC0:.*]] = vector.extract %[[ACC]][0, 0] : f32 from vector<1x2xf32>
// CHECK: vector.reduction <add>, %{{.*}} : vector<1xf32> into f32
// CHECK-COUNT-4: gpu.shuffle xor %{{.*}} : f32
// CHECK: %[[WITH_ACC0:.*]] = arith.addf %{{.*}}, %[[ACC0]] : f32
// CHECK: %[[INS0:.*]] = vector.insert %[[WITH_ACC0]], %{{.*}} [0, 0] : f32 into vector<1x2xf32>
// CHECK: vector.extract_strided_slice %[[SRC]]
// CHECK-SAME: {offsets = [0, 0, 1], sizes = [1, 1, 1], strides = [1, 1, 1]}
// CHECK: %[[ACC1:.*]] = vector.extract %[[ACC]][0, 1] : f32 from vector<1x2xf32>
// CHECK: vector.reduction <add>, %{{.*}} : vector<1xf32> into f32
// CHECK-COUNT-4: gpu.shuffle xor %{{.*}} : f32
// CHECK: %[[WITH_ACC1:.*]] = arith.addf %{{.*}}, %[[ACC1]] : f32
// CHECK: vector.insert %[[WITH_ACC1]], %[[INS0]] [0, 1] : f32 into vector<1x2xf32>
// CHECK: gpu.return
gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane() {
%src = arith.constant
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>}
dense<0.0> : vector<1x16x2xf32>
%acc = arith.constant
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>}
dense<0.0> : vector<1x2xf32>
%1 = vector.multi_reduction <add>, %src, %acc
{
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
}
[1] : vector<1x16x2xf32> to vector<1x2xf32>
gpu.return
}
// CHECK-LABEL: gpu.func @vector_extract_from_2d
// CHECK: %[[EXT:.*]] = vector.extract %{{.*}}[0] : vector<1xf32> from vector<4x1xf32>
gpu.func @vector_extract_from_2d() {