diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp index 8c60ced4ed38..ccac78eb6d9d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp @@ -588,6 +588,17 @@ struct SgToWiMultiDimReduction assert(reductionDims.size() == 1 && "Expecting single reduction dimension for subgroup multi " "reduction op"); + // For rank > 2, ensure leading dimensions are unit. + VectorType sourceType = op.getSourceVectorType(); + int64_t rank = sourceType.getRank(); + if (rank > 2) { + ArrayRef shape = sourceType.getShape(); + if (llvm::any_of(shape.take_front(rank - 2), + [](int64_t d) { return d != 1; })) + return rewriter.notifyMatchFailure( + op, "only unit leading dimensions are supported for " + "multi_reduction with rank > 2"); + } if (isReductionLaneLocal(op)) { auto resLayout = xegpu::getTemporaryLayout(op->getOpResult(0)); VectorType resVecTy = dyn_cast(op.getType()); diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 6c902f725ca0..b3e40b317289 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -750,11 +750,18 @@ Value xegpu::lowerCrossLaneReductionToShuffles( TypedValue src, TypedValue acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter) { - // Expecting a 2D source vector. - assert(src.getType().getRank() == 2 && "expected a 2D source vector"); VectorType sourceType = src.getType(); - int64_t sourceH = sourceType.getShape()[0]; - int64_t sourceW = sourceType.getShape()[1]; + int64_t sourceRank = sourceType.getRank(); + // Expecting at least a 2D source vector. Leading dimensions (all except the + // last two) must be unit. + assert(sourceRank >= 2 && "expected at least a 2D source vector"); + for (int64_t i = 0; i < sourceRank - 2; ++i) + assert(sourceType.getShape()[i] == 1 && + "expected leading dimensions to be unit"); + int64_t rowIdx = sourceRank - 2; + int64_t columnIdx = sourceRank - 1; + int64_t sourceH = sourceType.getShape()[rowIdx]; + int64_t sourceW = sourceType.getShape()[columnIdx]; // Create a constant vector to hold the result of the reduction. TypedAttr zeroAttr = rewriter.getZeroAttr(sourceType.getElementType()); @@ -763,39 +770,46 @@ Value xegpu::lowerCrossLaneReductionToShuffles( DenseElementsAttr::get(acc.getType(), zeroAttr)); // nSlices is the number of reduction operations needed to reduce the entire - // source vector. For example, if reductionDim is 0, we are reducing across - // rows, and each slice is a column of the source vector. So the number of - // slices is the number of columns, which is sourceW. - int nSlices = (reductionDim == 0) ? sourceW : sourceH; + // source vector. For example, if reductionDim is the row dim, we are + // reducing across rows, and each slice is a column. So the number of slices + // is the number of columns, which is sourceW. + int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH; // For each slice of the source, extract the slice vector, do a reduction // and, insert the reduced value back to the result vector. + int64_t accRank = acc.getType().getRank(); for (int i = 0; i < nSlices; ++i) { - SmallVector sliceOffsets, sliceSizes; - if (reductionDim == 1) { - sliceOffsets = {i, 0}; - sliceSizes = {1, sourceW}; + // Build nD offsets, sizes, and strides. Leading unit dims get + // offset=0, size=1. The last two dims are set based on reductionDim. + SmallVector sliceOffsets(sourceRank, 0); + SmallVector sliceSizes(sourceRank, 1); + SmallVector strides(sourceRank, 1); + if (reductionDim == columnIdx) { + sliceOffsets[rowIdx] = i; + sliceSizes[columnIdx] = sourceW; } else { - sliceOffsets = {0, i}; - sliceSizes = {sourceH, 1}; + sliceOffsets[columnIdx] = i; + sliceSizes[rowIdx] = sourceH; } vector::ExtractStridedSliceOp extractOp = vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets, - sliceSizes, {1, 1}); + sliceSizes, strides); int64_t nSliceElements = extractOp.getResult().getType().getNumElements(); vector::ShapeCastOp slice = vector::ShapeCastOp::create( rewriter, loc, VectorType::get({nSliceElements}, sourceType.getElementType()), extractOp.getResult()); - Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, i); + SmallVector accIdx(accRank, 0); + accIdx[accRank - 1] = i; + Value accExtract = vector::ExtractOp::create(rewriter, loc, acc, accIdx); Value fullReduce = xegpu::subgroupReduction(loc, rewriter, slice, kind, reductionSize); fullReduce = vector::makeArithReduction(rewriter, loc, kind, fullReduce, accExtract); - reductionResult = - vector::InsertOp::create(rewriter, loc, fullReduce, reductionResult, i); + reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce, + reductionResult, accIdx); } return reductionResult; } diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir index 9c4f469ea475..303214e54403 100644 --- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir +++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir @@ -461,6 +461,59 @@ gpu.func @vector_multi_reduction_dim0_distributed_dim1_reduction(%laneid: index) gpu.return } +// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x16x2xf32> +// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32> +// CHECK: %[[RED:.*]] = vector.multi_reduction , %[[CST]], %[[CST_0]] [1] : vector<1x16x2xf32> to vector<1x2xf32> +// CHECK: gpu.return +gpu.func @vector_multi_reduction_3d_leading_unit_dim_lane_local() { + %src = arith.constant + {layout_result_0 = #xegpu.layout} + dense<0.0> : vector<1x16x32xf32> + %acc = arith.constant + {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} + dense<0.0> : vector<1x32xf32> + %1 = vector.multi_reduction , %src, %acc + { + layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]> + } + [1] : vector<1x16x32xf32> to vector<1x32xf32> + gpu.return +} + +// CHECK-LABEL: gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane +// CHECK-DAG: %[[SRC:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x2xf32> +// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<1x2xf32> +// CHECK: vector.extract_strided_slice %[[SRC]] +// CHECK-SAME: {offsets = [0, 0, 0], sizes = [1, 1, 1], strides = [1, 1, 1]} +// CHECK: %[[ACC0:.*]] = vector.extract %[[ACC]][0, 0] : f32 from vector<1x2xf32> +// CHECK: vector.reduction , %{{.*}} : vector<1xf32> into f32 +// CHECK-COUNT-4: gpu.shuffle xor %{{.*}} : f32 +// CHECK: %[[WITH_ACC0:.*]] = arith.addf %{{.*}}, %[[ACC0]] : f32 +// CHECK: %[[INS0:.*]] = vector.insert %[[WITH_ACC0]], %{{.*}} [0, 0] : f32 into vector<1x2xf32> +// CHECK: vector.extract_strided_slice %[[SRC]] +// CHECK-SAME: {offsets = [0, 0, 1], sizes = [1, 1, 1], strides = [1, 1, 1]} +// CHECK: %[[ACC1:.*]] = vector.extract %[[ACC]][0, 1] : f32 from vector<1x2xf32> +// CHECK: vector.reduction , %{{.*}} : vector<1xf32> into f32 +// CHECK-COUNT-4: gpu.shuffle xor %{{.*}} : f32 +// CHECK: %[[WITH_ACC1:.*]] = arith.addf %{{.*}}, %[[ACC1]] : f32 +// CHECK: vector.insert %[[WITH_ACC1]], %[[INS0]] [0, 1] : f32 into vector<1x2xf32> +// CHECK: gpu.return +gpu.func @vector_multi_reduction_3d_leading_unit_dim_cross_lane() { + %src = arith.constant + {layout_result_0 = #xegpu.layout} + dense<0.0> : vector<1x16x2xf32> + %acc = arith.constant + {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]>} + dense<0.0> : vector<1x2xf32> + %1 = vector.multi_reduction , %src, %acc + { + layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1]> + } + [1] : vector<1x16x2xf32> to vector<1x2xf32> + gpu.return +} + // CHECK-LABEL: gpu.func @vector_extract_from_2d // CHECK: %[[EXT:.*]] = vector.extract %{{.*}}[0] : vector<1xf32> from vector<4x1xf32> gpu.func @vector_extract_from_2d() {