From 9f500046511e6e85345d9229bb2d4d2e627bb9d4 Mon Sep 17 00:00:00 2001 From: Nishant Patel Date: Wed, 1 Apr 2026 13:55:11 -0700 Subject: [PATCH] [MLIR][XeGPU] Enhance the peephole optimization to remove the convert_layout after multi-reduction rewrite (#188849) --- .../Transforms/XeGPUPeepHoleOptimizer.cpp | 25 ++++++ .../test/Dialect/XeGPU/peephole-optimize.mlir | 85 +++++++++++++------ 2 files changed, 82 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp index 0ece695aed51..0c7977bb241d 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp @@ -442,6 +442,20 @@ class MultiRed2dOpPattern auto loc = reductionOp.getLoc(); auto acc = reductionOp.getAcc(); + // If the result is scalar after reduction, look for consumer + // convert_layout op and remove it. The layout propagation pass will + // re-install it properly after the decomposition. + Type resultType = reductionOp.getResult().getType(); + if (resultType.isIntOrFloat()) { + for (auto &use : reductionOp.getResult().getUses()) { + if (auto convertLayoutOp = + llvm::dyn_cast(use.getOwner())) { + rewriter.replaceOp(convertLayoutOp, reductionOp.getResult()); + break; + } + } + } + SmallVector accShape(sourceVecType.getShape()); accShape.erase(accShape.begin() + intraLaneDim); Type eTy = sourceVecType.getElementType(); @@ -576,6 +590,17 @@ struct XeGPUPeepHoleOptimizerPass final MLIRContext *ctx = &getContext(); RewritePatternSet emptyPatterns(ctx); (void)applyPatternsGreedily(getOperation(), std::move(emptyPatterns)); + + // Remove the temporary layout after all patterns are applied. + getOperation()->walk([](Operation *op) { + SmallVector attrsToRemove; + for (auto namedAttr : op->getDiscardableAttrs()) { + if (isa(namedAttr.getValue())) + attrsToRemove.push_back(namedAttr.getName()); + } + for (auto attrName : attrsToRemove) + op->removeDiscardableAttr(attrName); + }); } }; diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir index 06008ccafbcc..f8dfd9a082ba 100644 --- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir +++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir @@ -13,8 +13,7 @@ // CHECK-SAME: {layout = #xegpu.layout} // CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> // CHECK-SAME: -> vector<16x8xi32> -// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> +// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]] : vector<16x8xi32> to vector<16x16xf16> #a = #xegpu.layout #b = #xegpu.layout #bt = #xegpu.layout @@ -41,8 +40,7 @@ gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x // CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]] // CHECK-SAME: {layout = #xegpu.layout} // CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> -// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x32xi8> +// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]] : vector<16x8xi32> to vector<16x32xi8> #a = #xegpu.layout #b = #xegpu.layout #bt = #xegpu.layout @@ -76,8 +74,7 @@ gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8 // CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] // CHECK-SAME: <{layout = #xegpu.layout}> : // CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> -// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] : vector<16x8xi32> to vector<16x16xf16> #a = #xegpu.layout #b = #xegpu.layout #bt = #xegpu.layout @@ -119,8 +116,7 @@ gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16 // CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] <{layout = #xegpu.layout< // CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}> : // CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> -// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] : vector<16x8xi32> to vector<16x16xf16> #a = #xegpu.layout #b = #xegpu.layout #bt = #xegpu.layout @@ -165,18 +161,15 @@ gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %ar // CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> // CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout, -// CHECK-SAME: offsets = [0, 0], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32> +// CHECK-SAME: {offsets = [0, 0], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32> // CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index // CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] <{layout = #xegpu.layout}> // CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> // CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout, -// CHECK-SAME: offsets = [0, 8], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32> -// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout} : vector<32x16xi32> to vector<32x32xf16> +// CHECK-SAME: {offsets = [0, 8], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32> +// CHECK: %{{.*}} = vector.bitcast %[[T10]] : vector<32x16xi32> to vector<32x32xf16> #a = #xegpu.layout #b = #xegpu.layout #bt = #xegpu.layout @@ -237,15 +230,13 @@ gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2 // CHECK-SAME: <{layout = #xegpu.layout}> : // CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> // CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : -// CHECK-SAME: vector<32x8xi32> to vector<32x16xf16> +// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> // CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index // CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] // CHECK-SAME: <{layout = #xegpu.layout}> : // CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> // CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : -// CHECK-SAME: vector<32x8xi32> to vector<32x16xf16> +// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> #a = #xegpu.layout #b = #xegpu.layout #bt = #xegpu.layout @@ -294,16 +285,15 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg // ----- // CHECK-LABEL: gpu.func @vector_reduce_2d( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) { -// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> -// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<16xindex> +// CHECK: %[[MASK:.*]] = arith.constant dense : vector<16xi1> +// CHECK: %[[OFFSET:.*]] = arith.constant dense<0> : vector<16xindex> // CHECK: %[[ACC_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> -// CHECK: %[[ACC_SCALAR:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} 1.000000e+00 : f32 +// CHECK: %[[ACC_SCALAR:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout> // CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> -> vector<4x16xf32> // CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction , %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32> // CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction , %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32 -// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] -// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} : f32 to vector<16xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : f32 to vector<16xf32> // CHECK: xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]] // CHECK-SAME: <{layout = #xegpu.slice<#xegpu.layout, dims = [0]>}> // CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1> @@ -333,17 +323,16 @@ gpu.module @xevm_test { // ----- // CHECK-LABEL: gpu.func @vector_reduce_2d_with_leading_unit_dims( // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) { -// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense : vector<16xi1> -// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<16xindex> +// CHECK: %[[MASK:.*]] = arith.constant dense : vector<16xi1> +// CHECK: %[[OFFSET:.*]] = arith.constant dense<0> : vector<16xindex> // CHECK: %[[ACC_2D:.*]] = arith.constant dense<0.000000e+00> : vector<1x16xf32> -// CHECK: %[[ACC_1D:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [1, 2]>} dense<1.000000e+00> : vector<1xf32> +// CHECK: %[[ACC_1D:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32> // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout> // CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> -> vector<4x16xf32> // CHECK: %[[SHAPED:.*]] = vector.shape_cast %[[LOADED]] : vector<4x16xf32> to vector<1x4x16xf32> // CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction , %[[SHAPED]], %[[ACC_2D]] [1] : vector<1x4x16xf32> to vector<1x16xf32> // CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction , %[[REDUCE_1]], %[[ACC_1D]] [1] : vector<1x16xf32> to vector<1xf32> -// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<1xf32> to vector<16xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : vector<1xf32> to vector<16xf32> // CHECK: xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]] // CHECK-SAME: <{layout = #xegpu.layout}> // CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1> @@ -370,3 +359,43 @@ gpu.module @xevm_test { gpu.return } } + +// ----- +// CHECK-LABEL: gpu.func @reduce_2d_scalar_convert_layout( +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) { +// CHECK: %[[ACC_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32> +// CHECK: %[[ACC_SCALAR:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout> +// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> -> vector<4x16xf32> +// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction , %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32> +// CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction , %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32 +// CHECK-NOT: xegpu.convert_layout +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : f32 to vector<16xf32> +// CHECK: xegpu.store %[[BCAST]], %[[ARG1]] +gpu.module @xevm_test { + gpu.func @reduce_2d_scalar_convert_layout(%src: memref<4x16xf32>, %dst: memref<256xf32>) { + %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} 1.0 : f32 + %tdesc = xegpu.create_nd_tdesc %src : memref<4x16xf32> + -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout> + %load = xegpu.load_nd %tdesc[0, 0] + : !xegpu.tensor_desc<4x16xf32, #xegpu.layout> + -> vector<4x16xf32> + %reduce = vector.multi_reduction , %load, %cst + {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0, 1]>} + [0, 1] : vector<4x16xf32> to f32 + %cvt = xegpu.convert_layout %reduce + <{input_layout = #xegpu.slice<#xegpu.layout, dims = [0, 1]>, + target_layout = #xegpu.slice<#xegpu.layout, dims = [0, 1]>}> + : f32 + %reduce_bcast = vector.broadcast %cvt + {layout_result_0 = #xegpu.slice<#xegpu.layout, dims = [0]>} + : f32 to vector<16xf32> + + %offset = arith.constant {layout_result_0 = #xegpu.layout} dense<0> : vector<16xindex> + %mask = arith.constant {layout_result_0 = #xegpu.layout} dense<1> : vector<16xi1> + + xegpu.store %reduce_bcast, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.layout, dims = [0]>} : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1> + gpu.return + } +} +