From c6db35fd343e43f6574062b6b580ec16ff9fb67a Mon Sep 17 00:00:00 2001 From: Charitha Saumya <136391709+charithaintc@users.noreply.github.com> Date: Thu, 26 Feb 2026 14:51:37 -0800 Subject: [PATCH] [mlir][xegpu] Retain order attribute during load + transpose optimization. (#183608) As described in the title, the `order` attribute is ignored in this transformation, causing downstream test failures. --- .../Transforms/XeGPUPeepHoleOptimizer.cpp | 5 +- .../test/Dialect/XeGPU/peephole-optimize.mlir | 106 ++++++++++-------- 2 files changed, 62 insertions(+), 49 deletions(-) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp index 8694bca974df..3b3b11cebe21 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPeepHoleOptimizer.cpp @@ -145,8 +145,9 @@ static xegpu::TensorDescType tryOptimize(xegpu::TensorDescType tdescType, SmallVector supportedShape = {supportedHeight, supportedWidth}; xegpu::LayoutAttr newLayout = xegpu::LayoutAttr::get( - tdescType.getContext(), - tdescType.getLayoutAttr().getLaneLayout().asArrayRef(), {1, 1}); + tdescType.getContext(), tdescType.getLayoutAttr().getLaneLayout(), + DenseI32ArrayAttr::get(tdescType.getContext(), {1, 1}), + tdescType.getLayoutAttr().getOrder()); // Array length can not be larger than 1 for transpose case. 
return xegpu::TensorDescType::get(supportedShape, newElemTy, arrayLen, tdescType.getBoundaryCheck(), diff --git a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir index 56a4b263255e..83fec045b997 100644 --- a/mlir/test/Dialect/XeGPU/peephole-optimize.mlir +++ b/mlir/test/Dialect/XeGPU/peephole-optimize.mlir @@ -8,14 +8,15 @@ // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xf16> -> index // CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64 // CHECK: %[[BDESC:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C32]]], strides : [%[[C32]], 1] : i64 -// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> // CHECK-NEXT: %[[B:.*]] = xegpu.load_nd %[[BDESC]][%{{.*}}, %[[C16]]] -// CHECK-SAME: {layout = #xegpu.layout} -// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK-SAME: {layout = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK-SAME: -> vector<16x8xi32> // CHECK: %[[BITCAST:.*]] = vector.bitcast %[[B]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> #a = #xegpu.layout -#b = #xegpu.layout +#b = #xegpu.layout #bt = #xegpu.layout gpu.module @xevm_module { gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x16xf32> { @@ -36,14 +37,14 @@ gpu.func @no_scf(%arg0: memref<64x64xf16>, %arg1: vector<8x16xf16>) -> vector<8x // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x64xi8> -> index // CHECK: %[[T0:.*]] = arith.index_cast %[[PTR]] : index to i64 // CHECK: %[[T1:.*]] = xegpu.create_nd_tdesc %[[T0]], shape : [64, %[[C16]]], strides : [%[[C16]], 1] : i64 -// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK-SAME: -> 
!xegpu.tensor_desc<16x8xi32, #xegpu.layout> // CHECK: %[[T2:.*]] = xegpu.load_nd %[[T1]][%{{.*}}, %[[C16]]] -// CHECK-SAME: {layout = #xegpu.layout} -// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK-SAME: {layout = #xegpu.layout} +// CHECK-SAME: : !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> // CHECK: %[[T3:.*]] = vector.bitcast %[[T2]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x32xi8> +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x32xi8> #a = #xegpu.layout -#b = #xegpu.layout +#b = #xegpu.layout #bt = #xegpu.layout #c = #xegpu.layout gpu.module @xevm_module { @@ -69,16 +70,16 @@ gpu.func @no_scf_i8(%arg0: memref<64x64xi8>, %arg1: vector<8x32xi8>) -> vector<8 // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index // CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64 // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%c128, 1] -// CHECK-SAME: : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK-SAME: : i64 -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> // CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) { // CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index // CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] -// CHECK-SAME: <{layout = #xegpu.layout}> : -// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> -// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} -// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16> +// CHECK-SAME: <{layout = #xegpu.layout}> : +// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> #a = #xegpu.layout -#b = 
#xegpu.layout +#b = #xegpu.layout #bt = #xegpu.layout gpu.module @xevm_module { gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { @@ -112,16 +113,16 @@ gpu.func @gemm_b_transpose(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16 // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index // CHECK: %[[T3:.*]] = arith.index_cast %[[PTR]] : index to i64 // CHECK: %[[T4:.*]] = xegpu.create_nd_tdesc %[[T3]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 -// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> +// CHECK-SAME: -> !xegpu.tensor_desc<16x8xi32, #xegpu.layout> // CHECK: %{{.*}} = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>) { // CHECK: %[[T7:.*]] = arith.shrui %[[K]], %[[C1]] : index // CHECK-NEXT: %[[T8:.*]] = xegpu.load_nd %[[T4]][%{{.*}}, %[[T7]]] <{layout = #xegpu.layout< -// CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1]>}> : -// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> -// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} -// CHECK-SAME: : vector<16x8xi32> to vector<16x16xf16> +// CHECK-SAME: lane_layout = [16, 1], lane_data = [1, 1], order = [0, 1]>}> : +// CHECK-SAME: !xegpu.tensor_desc<16x8xi32, #xegpu.layout> -> vector<16x8xi32> +// CHECK-NEXT: %{{.*}} = vector.bitcast %[[T8]] {layout_result_0 = #xegpu.layout} : vector<16x8xi32> to vector<16x16xf16> #a = #xegpu.layout -#b = #xegpu.layout +#b = #xegpu.layout #bt = #xegpu.layout gpu.module @xevm_module { gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { @@ -156,24 +157,28 @@ gpu.func @nested_scf(%arg0: memref<256x256xf16>, %arg1: memref<256x256xf16>, %ar // CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index // CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64 // 
CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 -// CHECK-SAME: -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout> +// CHECK-SAME: -> !xegpu.tensor_desc<32x8xi32, #xegpu.layout> // CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { // CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index -// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] <{layout = #xegpu.layout}> -// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] <{layout = #xegpu.layout}> +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> // CHECK: %[[T7:.*]] = vector.insert_strided_slice %[[T6]], %[[CST]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout, offsets = [0, 0], strides = [1, 1]} -// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32> +// CHECK-SAME: {layout_result_0 = #xegpu.layout, +// CHECK-SAME: offsets = [0, 0], strides = [1, 1]} : vector<32x8xi32> into vector<32x16xi32> // CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index -// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] <{layout = #xegpu.layout}> -// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] <{layout = #xegpu.layout}> +// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> // CHECK: %[[T10:.*]] = vector.insert_strided_slice %[[T9]], %[[T7]] -// CHECK-SAME: {layout_result_0 = #xegpu.layout, offsets = [0, 8], strides = [1, 1]} -// CHECK-SAME: : vector<32x8xi32> into vector<32x16xi32> -// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout} -// CHECK-SAME: : vector<32x16xi32> to vector<32x32xf16> +// CHECK-SAME: {layout_result_0 = #xegpu.layout, +// CHECK-SAME: offsets = [0, 8], strides = [1, 1]} 
: vector<32x8xi32> into vector<32x16xi32> +// CHECK: %{{.*}} = vector.bitcast %[[T10]] {layout_result_0 = #xegpu.layout} : vector<32x16xi32> to vector<32x32xf16> #a = #xegpu.layout -#b = #xegpu.layout +#b = #xegpu.layout #bt = #xegpu.layout gpu.module @xevm_module { gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) { @@ -219,23 +224,30 @@ gpu.func @large_loads(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2 // CHECK: %[[C128:.*]] = arith.constant 128 : index // CHECK: %[[C8:.*]] = arith.constant 8 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<256x256xf16> -> index +// CHECK: %[[PTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : +// CHECK-SAME: memref<256x256xf16> -> index // CHECK: %[[T2:.*]] = arith.index_cast %[[PTR]] : index to i64 -// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], strides : [%[[C128]], 1] : i64 -> -// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { +// CHECK: %[[T3:.*]] = xegpu.create_nd_tdesc %[[T2]], shape : [256, %[[C128]]], +// CHECK-SAME: strides : [%[[C128]], 1] : i64 -> +// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout> +// CHECK: %{{.*}}:4 = scf.for %[[K:.*]] = %{{.*}} iter_args(%{{.*}}) -> +// CHECK-SAME: (vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>, vector<8x16xf32>) { // CHECK: %[[T5:.*]] = arith.shrui %[[K]], %[[C1]] : index -// CHECK: %[[T6:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T5]]] <{layout = #xegpu.layout}> -// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> -// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] {layout_result_0 = #xegpu.layout} -// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> +// CHECK: %[[T6:.*]] = xegpu.load_nd 
%[[T3]][%{{.*}}, %[[T5]]] +// CHECK-SAME: <{layout = #xegpu.layout}> : +// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T7:.*]] = vector.bitcast %[[T6]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : +// CHECK-SAME: vector<32x8xi32> to vector<32x16xf16> // CHECK: %[[T8:.*]] = arith.addi %[[T5]], %[[C8]] : index -// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] <{layout = #xegpu.layout}> -// CHECK-SAME: : !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> -// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] {layout_result_0 = #xegpu.layout} -// CHECK-SAME: : vector<32x8xi32> to vector<32x16xf16> +// CHECK: %[[T9:.*]] = xegpu.load_nd %[[T3]][%{{.*}}, %[[T8]]] +// CHECK-SAME: <{layout = #xegpu.layout}> : +// CHECK-SAME: !xegpu.tensor_desc<32x8xi32, #xegpu.layout> -> vector<32x8xi32> +// CHECK: %[[T10:.*]] = vector.bitcast %[[T9]] +// CHECK-SAME: {layout_result_0 = #xegpu.layout} : +// CHECK-SAME: vector<32x8xi32> to vector<32x16xf16> #a = #xegpu.layout -#b = #xegpu.layout +#b = #xegpu.layout #bt = #xegpu.layout gpu.module @xevm_module { gpu.func @array_length(%arg0: vector<8x16xf16>, %arg1: memref<256x256xf16>, %arg2: memref<256x256xf32>) {