diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h index e7cae506d9f4..5a806799e896 100644 --- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h +++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h @@ -147,6 +147,14 @@ Value lowerToVectorReductions(TypedValue src, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter); +/// Creates a constant vector filled with the neutral (identity) value for the +/// given reduction kind. For example: 0 for ADD/OR/XOR, 1 for MUL/AND, +/// max/min signed/unsigned int for MINSI/MINUI/MAXSI/MAXUI, and +/-infinity +/// for float min/max operations. Returns nullptr if the element type is +/// incompatible with the requested reduction kind. +Value createReductionNeutralValue(OpBuilder &builder, Location loc, + VectorType type, vector::CombiningKind kind); + /// Lowers cross-lane reductions to shuffle operations on a 2D vector. /// Extracts slices along the reduction dimension, performs subgroup reductions /// with shuffles across reductionSize work-items, and inserts the results back diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp index d7a9b7ba377f..0ece695aed51 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp @@ -428,10 +428,8 @@ class MultiRed2dOpPattern matchAndRewrite(vector::MultiDimReductionOp reductionOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto sourceVecType = reductionOp.getSourceVectorType(); - if (reductionOp.getReductionDims().size() != 2 || - sourceVecType.getRank() != 2) - return rewriter.notifyMatchFailure( - reductionOp, "Expected 2D multi reduction of a 2D source"); + if (reductionOp.getReductionDims().size() != 2) + return rewriter.notifyMatchFailure(reductionOp, "Expected 2D reduction"); auto resLayout = xegpu::getDistributeLayoutAttr(reductionOp.getResult()); // Retrieve and order dims for 1D decomposition (prefer intra-lane first). auto dims = llvm::to_vector(reductionOp.getReductionDims()); @@ -444,33 +442,22 @@ class MultiRed2dOpPattern auto loc = reductionOp.getLoc(); auto acc = reductionOp.getAcc(); - // The first reduction's dist attribute does not have the cross lane dim. - auto resSliceLayoutAttr = cast(resLayout); - SmallVector dropDims{crossLaneDim}; - auto intraLaneRedResLayout = resSliceLayoutAttr.dropSliceDims(dropDims); - SmallVector accShape(sourceVecType.getShape()); accShape.erase(accShape.begin() + intraLaneDim); - if (acc) { - acc = vector::BroadcastOp::create( - rewriter, loc, - VectorType::get(accShape, sourceVecType.getElementType()), acc); - xegpu::setDistributeLayoutAttr( - llvm::dyn_cast(acc), - cast(intraLaneRedResLayout)); - } - Value intraLaneReduced = vector::MultiDimReductionOp::create( - rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), acc, - ArrayRef(intraLaneDim)); - xegpu::setDistributeLayoutAttr( - llvm::dyn_cast(intraLaneReduced), - cast(intraLaneRedResLayout)); + Type eTy = sourceVecType.getElementType(); + Value constNeutralVal = xegpu::createReductionNeutralValue( + rewriter, loc, VectorType::get(accShape, eTy), reductionOp.getKind()); - Value crossLaneReduced = vector::ReductionOp::create( - rewriter, loc, reductionOp.getKind(), intraLaneReduced, nullptr); - xegpu::setDistributeLayoutAttr( - llvm::dyn_cast(crossLaneReduced), - cast(resLayout)); + Value intraLaneReduced = vector::MultiDimReductionOp::create( + rewriter, loc, reductionOp.getKind(), reductionOp.getSource(), + constNeutralVal, ArrayRef(intraLaneDim)); + + // Adjust crossLaneDim after the first reduction. + if (crossLaneDim > intraLaneDim) + crossLaneDim -= 1; + Value crossLaneReduced = vector::MultiDimReductionOp::create( + rewriter, loc, reductionOp.getKind(), intraLaneReduced, acc, + ArrayRef(crossLaneDim)); assert(crossLaneReduced.getType() == reductionOp.getResult().getType() && "Type mismatch"); rewriter.replaceOp(reductionOp, crossLaneReduced); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9f8dbc15f642..b404ecf18984 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1198,86 +1198,6 @@ struct WgToSgVectorShapeCastOp } }; -static Value createAccumulator(ConversionPatternRewriter &rewriter, - Location loc, VectorType type, - vector::CombiningKind kind) { - Type elemTy = type.getElementType(); - - switch (kind) { - case vector::CombiningKind::ADD: - case vector::CombiningKind::XOR: - case vector::CombiningKind::OR: - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy))); - - case vector::CombiningKind::MUL: - case vector::CombiningKind::AND: - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, rewriter.getOneAttr(elemTy))); - - case vector::CombiningKind::MINSI: - // Use max signed int value for signed integer min - if (auto intTy = dyn_cast(elemTy)) { - auto maxVal = APInt::getSignedMaxValue(intTy.getWidth()); - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, - rewriter.getIntegerAttr(elemTy, maxVal))); - } - return nullptr; - - case vector::CombiningKind::MINUI: - if (auto intTy = dyn_cast(elemTy)) { - auto maxVal = APInt::getMaxValue(intTy.getWidth()); - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, - rewriter.getIntegerAttr(elemTy, maxVal))); - } - return nullptr; - - case vector::CombiningKind::MAXSI: - if (auto intTy = dyn_cast(elemTy)) { - auto minVal = APInt::getSignedMinValue(intTy.getWidth()); - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, - rewriter.getIntegerAttr(elemTy, minVal))); - } - return nullptr; - - case vector::CombiningKind::MAXUI: - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, rewriter.getZeroAttr(elemTy))); - - case vector::CombiningKind::MINNUMF: - case vector::CombiningKind::MINIMUMF: - // Use +infinity for float min operations - if (auto floatTy = dyn_cast(elemTy)) { - auto posInf = APFloat::getInf(floatTy.getFloatSemantics()); - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, posInf))); - } - return nullptr; - - case vector::CombiningKind::MAXNUMF: - case vector::CombiningKind::MAXIMUMF: - // Use -infinity for float max operations - if (auto floatTy = dyn_cast(elemTy)) { - auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true); - return arith::ConstantOp::create( - rewriter, loc, type, - DenseElementsAttr::get(type, rewriter.getFloatAttr(elemTy, negInf))); - } - return nullptr; - } - return nullptr; -} - /// This pattern transforms vector.multi_dim_reduction operations from /// workgroup-level to subgroup-level execution with support for multiple /// reduction dimensions. @@ -1359,8 +1279,8 @@ struct WgToSgMultiDimReductionOp VectorType newDstType = VectorType::get(sgDstShape, elemTy); for (auto sgSrc : sgSrcs) { // Create ZERO accumulator for local reduction - auto neutralLocalAcc = - createAccumulator(rewriter, loc, newDstType, op.getKind()); + auto neutralLocalAcc = xegpu::createReductionNeutralValue( + rewriter, loc, newDstType, op.getKind()); // Local reduction with ZERO accumulator auto localReduce = vector::MultiDimReductionOp::create( rewriter, loc, newDstType, op.getKind(), sgSrc, neutralLocalAcc, @@ -1481,8 +1401,8 @@ struct WgToSgMultiDimReductionOp /*layout=*/nullptr); // Step 6: Perform final reduction with ZERO accumulator - auto neutralFinalAcc = - createAccumulator(rewriter, loc, newDstType, op.getKind()); + auto neutralFinalAcc = xegpu::createReductionNeutralValue( + rewriter, loc, newDstType, op.getKind()); auto finalReduce = vector::MultiDimReductionOp::create( rewriter, loc, newDstType, op.getKind(), slmLoadOp.getResult(), diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index a57bf8512dde..c30cd6edefc5 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -797,6 +797,83 @@ Value xegpu::lowerCrossLaneReductionToShuffles( return reductionResult; } +Value xegpu::createReductionNeutralValue(OpBuilder &builder, Location loc, + VectorType type, + vector::CombiningKind kind) { + Type elemTy = type.getElementType(); + + switch (kind) { + case vector::CombiningKind::ADD: + case vector::CombiningKind::XOR: + case vector::CombiningKind::OR: + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getZeroAttr(elemTy))); + + case vector::CombiningKind::MUL: + case vector::CombiningKind::AND: + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getOneAttr(elemTy))); + + case vector::CombiningKind::MINSI: + // Use max signed int value for signed integer min + if (auto intTy = dyn_cast(elemTy)) { + auto maxVal = APInt::getSignedMaxValue(intTy.getWidth()); + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal))); + } + return nullptr; + + case vector::CombiningKind::MINUI: + if (auto intTy = dyn_cast(elemTy)) { + auto maxVal = APInt::getMaxValue(intTy.getWidth()); + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, maxVal))); + } + return nullptr; + + case vector::CombiningKind::MAXSI: + if (auto intTy = dyn_cast(elemTy)) { + auto minVal = APInt::getSignedMinValue(intTy.getWidth()); + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getIntegerAttr(elemTy, minVal))); + } + return nullptr; + + case vector::CombiningKind::MAXUI: + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getZeroAttr(elemTy))); + + case vector::CombiningKind::MINNUMF: + case vector::CombiningKind::MINIMUMF: + // Use +infinity for float min operations + if (auto floatTy = dyn_cast(elemTy)) { + auto posInf = APFloat::getInf(floatTy.getFloatSemantics()); + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, posInf))); + } + return nullptr; + + case vector::CombiningKind::MAXNUMF: + case vector::CombiningKind::MAXIMUMF: + // Use -infinity for float max operations + if (auto floatTy = dyn_cast(elemTy)) { + auto negInf = APFloat::getInf(floatTy.getFloatSemantics(), true); + return arith::ConstantOp::create( + builder, loc, type, + DenseElementsAttr::get(type, builder.getFloatAttr(elemTy, negInf))); + } + return nullptr; + } + return nullptr; +} + /// Explicit instantiations template int xegpu::getLargestDivisor(int dim, ArrayRef candidates, ArrayRef candidateMultiples); diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir index 83fec045b997..06008ccafbcc 100644 --- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir +++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir @@ -293,14 +293,20 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg // ----- // CHECK-LABEL: gpu.func @vector_reduce_2d( -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG2:[0-9a-zA-Z]+]]: memref<256xf32>) { -// CHECK: %[[ACC_VEC:.*]] = arith.constant dense<1.000000e+00> : vector<16xf32> +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) { +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> +// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<16xindex> +// CHECK: %[[ACC_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK: %[[ACC_SCALAR:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} 1.000000e+00 : f32 // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout> -// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> -> vector<4x16xf32> -// CHECK: %[[LOADED_REDUCED:.*]] = vector.multi_reduction , %[[LOADED]], %[[ACC_VEC]] -// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} [0] : vector<4x16xf32> to vector<16xf32> -// CHECK: %[[LOADED_REDUCED_FOR_CROSS:.*]] = vector.reduction , %[[LOADED_REDUCED]] -// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} : vector<16xf32> into f32 +// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> -> vector<4x16xf32> +// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction , %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32> +// CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction , %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32 +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] +// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} : f32 to vector<16xf32> +// CHECK: xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]] +// CHECK-SAME: <{layout = #xegpu.slice<#xegpu.layout, dims = [0]>}> +// CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1> gpu.module @xevm_test { gpu.func @vector_reduce_2d(%src: memref<4x16xf32>, %dst: memref<256xf32>) { %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} 1.0 : f32 @@ -323,3 +329,44 @@ gpu.module @xevm_test { gpu.return } } + +// ----- +// CHECK-LABEL: gpu.func @vector_reduce_2d_with_leading_unit_dims( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) { +// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> +// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<16xindex> +// CHECK: %[[ACC_2D:.*]] = arith.constant dense<0.000000e+00> : vector<1x16xf32> +// CHECK: %[[ACC_1D:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1, 2]>} dense<1.000000e+00> : vector<1xf32> +// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout> +// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> -> vector<4x16xf32> +// CHECK: %[[SHAPED:.*]] = vector.shape_cast %[[LOADED]] : vector<4x16xf32> to vector<1x4x16xf32> +// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction , %[[SHAPED]], %[[ACC_2D]] [1] : vector<1x4x16xf32> to vector<1x16xf32> +// CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction , %[[REDUCE_1]], %[[ACC_1D]] [1] : vector<1x16xf32> to vector<1xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<1xf32> to vector<16xf32> +// CHECK: xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]] +// CHECK-SAME: <{layout = #xegpu.layout}> +// CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1> +gpu.module @xevm_test { + gpu.func @vector_reduce_2d_with_leading_unit_dims(%src: memref<4x16xf32>, %dst: memref<256xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1, 2]>} dense<1.000000e+00> : vector<1xf32> + %tdesc = xegpu.create_nd_tdesc %src : memref<4x16xf32> + -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> + -> vector<4x16xf32> + %load1 = vector.broadcast %load {layout_result_0 = #xegpu.layout}: vector<4x16xf32> to vector<1x4x16xf32> + %reduce = vector.multi_reduction , %load1, %cst + {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1, 2]>} + [1, 2] : vector<1x4x16xf32> to vector<1xf32> + %reduce_bcast = vector.broadcast %reduce + {layout_result_0 = #xegpu.layout} + : vector<1xf32> to vector<16xf32> + + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<16xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<16xi1> + + xegpu.store %reduce_bcast, %dst[%offset], %mask {layout = #xegpu.layout} : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1> + gpu.return + } +}