[MLIR][XeGPU] Enhance the peephole optimization to remove the convert_layout after multi-reduction rewrite (#188849)

This commit is contained in:
Nishant Patel 2026-04-01 13:55:11 -07:00 committed by GitHub
parent d6d0876d1a
commit 9f50004651
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 82 additions and 28 deletions

View File

@ -442,6 +442,20 @@ class MultiRed2dOpPattern
auto loc = reductionOp.getLoc();
auto acc = reductionOp.getAcc();
// If the result is scalar after reduction (its type is an integer or float
// rather than a vector), look for a consumer xegpu.convert_layout op and
// remove it by replacing it with the reduction result itself. The layout
// propagation pass will re-install it properly after the decomposition.
Type resultType = reductionOp.getResult().getType();
if (resultType.isIntOrFloat()) {
for (auto &use : reductionOp.getResult().getUses()) {
if (auto convertLayoutOp =
llvm::dyn_cast<xegpu::ConvertLayoutOp>(use.getOwner())) {
// Forward the reduction result to everything that consumed the
// convert_layout result. The break stops the traversal right after the
// replacement, so the use list is not walked further once it has been
// modified.
// NOTE(review): only the first convert_layout user found is removed; any
// further convert_layout users of the same result are left in place —
// confirm a scalar reduction result has at most one such consumer.
rewriter.replaceOp(convertLayoutOp, reductionOp.getResult());
break;
}
}
}
SmallVector<int64_t> accShape(sourceVecType.getShape());
accShape.erase(accShape.begin() + intraLaneDim);
Type eTy = sourceVecType.getElementType();
@ -576,6 +590,17 @@ struct XeGPUPeepHoleOptimizerPass final
MLIRContext *ctx = &getContext();
RewritePatternSet emptyPatterns(ctx);
(void)applyPatternsGreedily(getOperation(), std::move(emptyPatterns));
// Once every pattern has run, strip the temporary layout annotations: any
// discardable attribute whose value is an xegpu::DistributeLayoutAttr is
// dropped from each operation visited by the walk.
getOperation()->walk([](Operation *op) {
  // Gather the matching attribute names first and erase them afterwards, so
  // the attribute list is never modified while it is being iterated.
  SmallVector<StringAttr> layoutAttrNames;
  for (NamedAttribute attr : op->getDiscardableAttrs())
    if (isa<xegpu::DistributeLayoutAttr>(attr.getValue()))
      layoutAttrNames.push_back(attr.getName());
  for (StringAttr name : layoutAttrNames)
    op->removeDiscardableAttr(name);
});
}
};

View File

@ -13,8 +13,7 @@
// CHECK-SAME: {layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}
// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>>
// CHECK-SAME: -> vector<16x8xi32>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>} : vector<16x8xi32> to vector<16x16xf16>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]] : vector<16x8xi32> to vector<16x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@ -41,8 +40,7 @@ gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x
// CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]]
// CHECK-SAME: {layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}
// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>> -> vector<16x8xi32>
// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4], order = [0, 1]>} : vector<16x8xi32> to vector<16x32xi8>
// CHECK: %[[T3:.*]] = vector.bitcast %[[T2]] : vector<16x8xi32> to vector<16x32xi8>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 4], order = [0, 1]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [4, 1]>
@ -76,8 +74,7 @@ gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8
// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]]
// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}> :
// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>> -> vector<16x8xi32>
// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2],
// CHECK-SAME: order = [0, 1]>} : vector<16x8xi32> to vector<16x16xf16>
// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] : vector<16x8xi32> to vector<16x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@ -119,8 +116,7 @@ gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16
// CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] <{layout = #xegpu.layout<
// CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}> :
// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>> -> vector<16x8xi32>
// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2],
// CHECK-SAME: order = [0, 1]>} : vector<16x8xi32> to vector<16x16xf16>
// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] : vector<16x8xi32> to vector<16x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@ -165,18 +161,15 @@ gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %ar
// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1],
// CHECK-SAME: order = [0, 1]>> -> vector<32x8xi32>
// CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>,
// CHECK-SAME: offsets = [0, 0], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32>
// CHECK-SAME: {offsets = [0, 0], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32>
// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] <{layout = #xegpu.layout<lane_layout = [16, 1],
// CHECK-SAME: lane_data = [1, 1], order = [0, 1]>}>
// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1],
// CHECK-SAME: lane_data = [1, 1], order = [0, 1]>> -> vector<32x8xi32>
// CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>,
// CHECK-SAME: offsets = [0, 8], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32>
// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1],
// CHECK-SAME: lane_data = [1, 2], order = [0, 1]>} : vector<32x16xi32> to vector<32x32xf16>
// CHECK-SAME: {offsets = [0, 8], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32>
// CHECK: %{{.*}} = vector.bitcast %[[T10]] : vector<32x16xi32> to vector<32x32xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@ -237,15 +230,13 @@ gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2
// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}> :
// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>> -> vector<32x8xi32>
// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>} :
// CHECK-SAME: vector<32x8xi32> to vector<32x16xf16>
// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
// CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index
// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]]
// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}> :
// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>> -> vector<32x8xi32>
// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>} :
// CHECK-SAME: vector<32x8xi32> to vector<32x16xf16>
// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16>
#a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
#b = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2], order = [0, 1]>
#bt = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
@ -294,16 +285,15 @@ gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg
// -----
// CHECK-LABEL: gpu.func @vector_reduce_2d(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK: %[[OFFSET:.*]] = arith.constant dense<0> : vector<16xindex>
// CHECK: %[[ACC_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK: %[[ACC_SCALAR:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} 1.000000e+00 : f32
// CHECK: %[[ACC_SCALAR:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32>
// CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]]
// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : f32 to vector<16xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : f32 to vector<16xf32>
// CHECK: xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]]
// CHECK-SAME: <{layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}>
// CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
@ -333,17 +323,16 @@ gpu.module @xevm_test {
// -----
// CHECK-LABEL: gpu.func @vector_reduce_2d_with_leading_unit_dims(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
// CHECK: %[[OFFSET:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<16xi1>
// CHECK: %[[OFFSET:.*]] = arith.constant dense<0> : vector<16xindex>
// CHECK: %[[ACC_2D:.*]] = arith.constant dense<0.000000e+00> : vector<1x16xf32>
// CHECK: %[[ACC_1D:.*]] = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>, dims = [1, 2]>} dense<1.000000e+00> : vector<1xf32>
// CHECK: %[[ACC_1D:.*]] = arith.constant dense<1.000000e+00> : vector<1xf32>
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
// CHECK: %[[SHAPED:.*]] = vector.shape_cast %[[LOADED]] : vector<4x16xf32> to vector<1x4x16xf32>
// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[SHAPED]], %[[ACC_2D]] [1] : vector<1x4x16xf32> to vector<1x16xf32>
// CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_1D]] [1] : vector<1x16xf32> to vector<1xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]]
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : vector<1xf32> to vector<16xf32>
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : vector<1xf32> to vector<16xf32>
// CHECK: xegpu.store %[[BCAST]], %[[ARG1]][%[[OFFSET]]], %[[MASK]]
// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
// CHECK-SAME: : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
@ -370,3 +359,43 @@ gpu.module @xevm_test {
gpu.return
}
}
// -----
// CHECK-LABEL: gpu.func @reduce_2d_scalar_convert_layout(
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<4x16xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
// CHECK: %[[ACC_VEC:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
// CHECK: %[[ACC_SCALAR:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]] : memref<4x16xf32> -> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
// CHECK: %[[LOADED:.*]] = xegpu.load_nd %[[TDESC]][0, 0] : !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<4x16xf32>
// CHECK: %[[REDUCE_1:.*]] = vector.multi_reduction <add>, %[[LOADED]], %[[ACC_VEC]] [0] : vector<4x16xf32> to vector<16xf32>
// CHECK: %[[REDUCE_2:.*]] = vector.multi_reduction <add>, %[[REDUCE_1]], %[[ACC_SCALAR]] [0] : vector<16xf32> to f32
// CHECK-NOT: xegpu.convert_layout
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[REDUCE_2]] : f32 to vector<16xf32>
// CHECK: xegpu.store %[[BCAST]], %[[ARG1]]
// Test input: a 4x16 f32 tile is loaded and reduced over both dimensions
// ([0, 1]) to a scalar f32. The scalar then goes through an
// xegpu.convert_layout before being broadcast to vector<16xf32> and stored.
// The FileCheck directives preceding this module expect the reduction to be
// split into two single-dimension reductions and the convert_layout to be gone.
gpu.module @xevm_test {
gpu.func @reduce_2d_scalar_convert_layout(%src: memref<4x16xf32>, %dst: memref<256xf32>) {
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>} 1.0 : f32
%tdesc = xegpu.create_nd_tdesc %src : memref<4x16xf32>
-> !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<4x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-> vector<4x16xf32>
%reduce = vector.multi_reduction <add>, %load, %cst
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>}
[0, 1] : vector<4x16xf32> to f32
// Scalar convert_layout consuming the reduction result; its input and target
// layouts are the same slice layout.
%cvt = xegpu.convert_layout %reduce
<{input_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>,
target_layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0, 1]>}>
: f32
%reduce_bcast = vector.broadcast %cvt
{layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
: f32 to vector<16xf32>
%offset = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<0> : vector<16xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<1> : vector<16xi1>
xegpu.store %reduce_bcast, %dst[%offset], %mask {layout = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : vector<16xf32>, memref<256xf32>, vector<16xindex>, vector<16xi1>
gpu.return
}
}