[MLIR][XeGPU] Support leading unit dims in vector.multi_reduction in sg to wi pass (#188767)
This PR adds support for transforming vector.multi_reduction with vectors > rank 2d with leading unit dims
This commit is contained in:
parent
acb3d81a93
commit
ad4d4c0f63
@ -588,6 +588,17 @@ struct SgToWiMultiDimReduction
|
||||
assert(reductionDims.size() == 1 &&
|
||||
"Expecting single reduction dimension for subgroup multi "
|
||||
"reduction op");
|
||||
// For rank > 2, ensure leading dimensions are unit.
|
||||
VectorType sourceType = op.getSourceVectorType();
|
||||
int64_t rank = sourceType.getRank();
|
||||
if (rank > 2) {
|
||||
ArrayRef<int64_t> shape = sourceType.getShape();
|
||||
if (llvm::any_of(shape.take_front(rank - 2),
|
||||
[](int64_t d) { return d != 1; }))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only unit leading dimensions are supported for "
|
||||
"multi_reduction with rank > 2");
|
||||
}
|
||||
if (isReductionLaneLocal(op)) {
|
||||
auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0));
|
||||
VectorType resVecTy = dyn_cast<VectorType>(op.getType());
|
||||
|
||||
@ -750,11 +750,18 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
|
||||
TypedValue<VectorType> src, TypedValue<VectorType> acc,
|
||||
vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize,
|
||||
Location loc, PatternRewriter &rewriter) {
|
||||
// Expecting a 2D source vector.
|
||||
assert(src.getType().getRank() == 2 && "expected a 2D source vector");
|
||||
VectorType sourceType = src.getType();
|
||||
int64_t sourceH = sourceType.getShape()[0];
|
||||
int64_t sourceW = sourceType.getShape()[1];
|
||||
int64_t sourceRank = sourceType.getRank();
|
||||
// Expecting at least a 2D source vector. Leading dimensions (all except the
|
||||
// last two) must be unit.
|
||||
assert(sourceRank >= 2 && "expected at least a 2D source vector");
|
||||
for (int64_t i = 0; i < sourceRank - 2; ++i)
|
||||
assert(sourceType.getShape()[i] == 1 &&
|
||||
"expected leading dimensions to be unit");
|
||||
int64_t rowIdx = sourceRank - 2;
|
||||
int64_t columnIdx = sourceRank - 1;
|
||||
int64_t sourceH = sourceType.getShape()[rowIdx];
|
||||
int64_t sourceW = sourceType.getShape()[columnIdx];
|
||||
|
||||
// Create a constant vector to hold the result of the reduction.
|
||||
TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType());
|
||||
@ -763,39 +770,46 @@ Value xegpu::lowerCrossLaneReductionToShuffles(
|
||||
DenseElementsAttr::get(acc.getType(), zeroAttr));
|
||||
|
||||
// nSlices is the number of reduction operations needed to reduce the entire
|
||||
// source vector. For example, if reductionDim is 0, we are reducing across
|
||||
// rows, and each slice is a column of the source vector. So the number of
|
||||
// slices is the number of columns, which is sourceW.
|
||||
int nSlices = (reductionDim == 0) ? sourceW : sourceH;
|
||||
// source vector. For example, if reductionDim is the row dim, we are
|
||||
// reducing across rows, and each slice is a column. So the number of slices
|
||||
// is the number of columns, which is sourceW.
|
||||
int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
|
||||
|
||||
// For each slice of the source, extract the slice vector, do a reduction
|
||||
// and, insert the reduced value back to the result vector.
|
||||
int64_t accRank = acc.getType().getRank();
|
||||
for (int i = 0; i < nSlices; ++i) {
|
||||
SmallVector<int64_t, 2> sliceOffsets, sliceSizes;
|
||||
if (reductionDim == 1) {
|
||||
sliceOffsets = {i, 0};
|
||||
sliceSizes = {1, sourceW};
|
||||
// Build nD offsets, sizes, and strides. Leading unit dims get
|
||||
// offset=0, size=1. The last two dims are set based on reductionDim.
|
||||
SmallVector<int64_t> sliceOffsets(sourceRank, 0);
|
||||
SmallVector<int64_t> sliceSizes(sourceRank, 1);
|
||||
SmallVector<int64_t> strides(sourceRank, 1);
|
||||
if (reductionDim == columnIdx) {
|
||||
sliceOffsets[rowIdx] = i;
|
||||
sliceSizes[columnIdx] = sourceW;
|
||||
} else {
|
||||
sliceOffsets = {0, i};
|
||||
sliceSizes = {sourceH, 1};
|
||||
sliceOffsets[columnIdx] = i;
|
||||
sliceSizes[rowIdx] = sourceH;
|
||||
}
|
||||
|
||||
vector::ExtractStridedSliceOp extractOp =
|
||||
vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
|
||||
sliceSizes, {1, 1});
|
||||
sliceSizes, strides);
|
||||
int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
|
||||
vector::ShapeCastOp slice = vector::ShapeCastOp::create(
|
||||
rewriter, loc,
|
||||
VectorType::get({nSliceElements}, sourceType.getElementType()),
|
||||
extractOp.getResult());
|
||||
|
||||
Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i);
|
||||
SmallVector<int64_t> accIdx(accRank, 0);
|
||||
accIdx[accRank - 1] = i;
|
||||
Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx);
|
||||
Value fullReduce =
|
||||
xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize);
|
||||
fullReduce =
|
||||
vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract);
|
||||
reductionResult =
|
||||
vector::InsertOp::create(rewriter, loc, fullReduce, reductionResult, i);
|
||||
reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
|
||||
reductionResult, accIdx);
|
||||
}
|
||||
return reductionResult;
|
||||
}
|
||||
|
||||
@ -461,6 +461,59 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index)
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local
|
||||
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32>
|
||||
// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
|
||||
// CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[CST]], %[[CST_0]] [1] : vector<1x16x2xf32> to vector<1x2xf32>
|
||||
// CHECK: gpu.return
|
||||
gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() {
|
||||
%src = arith.constant
|
||||
{layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>}
|
||||
dense<0.0> : vector<1x16x32xf32>
|
||||
%acc = arith.constant
|
||||
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>}
|
||||
dense<0.0> : vector<1x32xf32>
|
||||
%1 = vector.multi_reduction <add>, %src, %acc
|
||||
{
|
||||
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1]>
|
||||
}
|
||||
[1] : vector<1x16x32xf32> to vector<1x32xf32>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane
|
||||
// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x2xf32>
|
||||
// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32>
|
||||
// CHECK: vector.extract_strided_slice %[[SRC]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [1, 1, 1], strides = [1, 1, 1]}
|
||||
// CHECK: %[[ACC0:.*]] = vector.extract %[[ACC]][0, 0] : f32 from vector<1x2xf32>
|
||||
// CHECK: vector.reduction <add>, %{{.*}} : vector<1xf32> into f32
|
||||
// CHECK-COUNT-4: gpu.shuffle xor %{{.*}} : f32
|
||||
// CHECK: %[[WITH_ACC0:.*]] = arith.addf %{{.*}}, %[[ACC0]] : f32
|
||||
// CHECK: %[[INS0:.*]] = vector.insert %[[WITH_ACC0]], %{{.*}} [0, 0] : f32 into vector<1x2xf32>
|
||||
// CHECK: vector.extract_strided_slice %[[SRC]]
|
||||
// CHECK-SAME: {offsets = [0, 0, 1], sizes = [1, 1, 1], strides = [1, 1, 1]}
|
||||
// CHECK: %[[ACC1:.*]] = vector.extract %[[ACC]][0, 1] : f32 from vector<1x2xf32>
|
||||
// CHECK: vector.reduction <add>, %{{.*}} : vector<1xf32> into f32
|
||||
// CHECK-COUNT-4: gpu.shuffle xor %{{.*}} : f32
|
||||
// CHECK: %[[WITH_ACC1:.*]] = arith.addf %{{.*}}, %[[ACC1]] : f32
|
||||
// CHECK: vector.insert %[[WITH_ACC1]], %[[INS0]] [0, 1] : f32 into vector<1x2xf32>
|
||||
// CHECK: gpu.return
|
||||
gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane() {
|
||||
%src = arith.constant
|
||||
{layout_result_0 = #xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>}
|
||||
dense<0.0> : vector<1x16x2xf32>
|
||||
%acc = arith.constant
|
||||
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>}
|
||||
dense<0.0> : vector<1x2xf32>
|
||||
%1 = vector.multi_reduction <add>, %src, %acc
|
||||
{
|
||||
layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16, 1], lane_data = [1, 1, 1]>, dims = [1]>
|
||||
}
|
||||
[1] : vector<1x16x2xf32> to vector<1x2xf32>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: gpu.func @vector_extract_from_2d
|
||||
// CHECK: %[[EXT:.*]] = vector.extract %{{.*}}[0] : vector<1xf32> from vector<4x1xf32>
|
||||
gpu.func @vector_extract_from_2d() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user