[MLIR][XeGPU] Extend Wg-to-Sg Distribution of Multi-Reduction Op for round-robin layout (#189988)

This PR enhances the multi-reduction op pattern of the wg-to-sg distribution
pass:
1. allows each sg to have multiple distributions of sg_data tiles.
2. expands the slm buffer size.
3. constructs the layout based on the partially reduced vector and uses
layout.computeDistributedCoords() to compute coordinates. The layout is
constructed so that the store is cooperative and the load overlaps with
neighbouring threads.
4. perform save and load.
This commit is contained in:
Jianhui Li 2026-04-06 14:07:50 -07:00 committed by GitHub
parent 97d50c1490
commit 9bddf47198
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 123 additions and 144 deletions

View File

@ -1255,7 +1255,6 @@ struct WgToSgMultiDimReductionOp
bool isScalarResult = !dstVecType;
auto originalSrcShape = srcType.getShape();
int srcVecRank = originalSrcShape.size();
Type elemTy = srcType.getElementType();
xegpu::DistributeLayoutAttr layout =
@ -1268,9 +1267,11 @@ struct WgToSgMultiDimReductionOp
// Get sg_layout and sg_data from the parent layout
SmallVector<int64_t> sgLayout;
SmallVector<int64_t> sgData;
xegpu::DistributeLayoutAttr parentLayout;
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
parentLayout = sliceAttr.getParent();
sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
sgData = parentLayout.getEffectiveSgDataAsInt();
} else
return rewriter.notifyMatchFailure(
op, "Reduction should have SliceAttr layout");
@ -1330,26 +1331,33 @@ struct WgToSgMultiDimReductionOp
return success();
}
// Step 2: cross-subgroup reduction using SLM
// Step 2: cross-subgroup reduction using SLM - allocating slm memory
auto slmStoreDataShape = sgSrcShape;
for (int64_t dim : reductionDims)
slmStoreDataShape[dim] = 1;
VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
Value slmStoreData;
if (isScalarResult) {
// Scalar result: broadcast scalar to vector<1x...x1> for SLM store
slmStoreData = vector::BroadcastOp::create(
rewriter, loc, slmStoreDataType, localReductions[0]);
} else {
slmStoreData = vector::ShapeCastOp::create(
rewriter, loc, slmStoreDataType, localReductions[0]);
SmallVector<Value> slmStoreData;
for (auto localResult : localReductions) {
if (isScalarResult) {
// Scalar result: broadcast scalar to vector<1x...x1> for SLM store
slmStoreData.push_back(vector::BroadcastOp::create(
rewriter, loc, slmStoreDataType, localResult));
} else {
slmStoreData.push_back(vector::ShapeCastOp::create(
rewriter, loc, slmStoreDataType, localResult));
}
}
// for reduction dimension, SLM stores partial results from each subgroup
SmallVector<int64_t> slmShape(originalSrcShape.begin(),
originalSrcShape.end());
// for reduction dimension, SLM stores partial results from each subgroup
for (int64_t dim : reductionDims)
SmallVector<int> slmSgData(sgData.begin(), sgData.end());
SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
for (int dim : reductionDims) {
slmShape[dim] = sgLayout[dim];
slmSgData[dim] = 1;
}
xegpu::LayoutAttr slmStoreLayout =
xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
// Allocate SLM
auto bitWidth = elemTy.getIntOrFloatBitWidth();
@ -1363,82 +1371,61 @@ struct WgToSgMultiDimReductionOp
auto memDesc =
xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
// if localReductions have more than 1 result, not support
if (localReductions.size() > 1) {
return rewriter.notifyMatchFailure(
op,
"Multiple local reductions not supported in current implementation.");
}
// Step 4: Store local results to SLM
// Step 3: Store local results to SLM
auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
rewriter.getIndexType(), nullptr);
// Convert sgLayout to Values for delinearizeIndex
SmallVector<Value> sgLayoutValues;
for (int64_t dim : sgLayout)
sgLayoutValues.push_back(
arith::ConstantIndexOp::create(rewriter, loc, dim));
auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
sgLayoutValues);
if (failed(sgIdsResult))
auto slmStoreCoords =
slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
if (failed(slmStoreCoords))
return failure();
SmallVector<Value> sgIds = *sgIdsResult;
auto getSlmOffsets = [&](int64_t reductionDimStride) {
SmallVector<OpFoldResult> offsets;
offsets.reserve(srcVecRank);
for (int i = 0; i < srcVecRank; ++i) {
Value dimVal = sgIds[i];
int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
? reductionDimStride
: sgSrcShape[i];
Value strideVal =
arith::ConstantIndexOp::create(rewriter, loc, sgDataStride);
Value offsetVal =
arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
offsets.push_back(offsetVal);
}
return offsets;
};
SmallVector<OpFoldResult> slmStoreOffsets =
getSlmOffsets(/*reductionDimStride=*/1);
xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
memDesc.getResult(), slmStoreOffsets,
/*layout=*/nullptr);
for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
coordOfr,
/*layout=*/nullptr);
}
gpu::BarrierOp::create(rewriter, loc);
// Step 5: Load from SLM for final reduction
// Step 4: Load from SLM for final reduction
SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
for (int64_t dim : reductionDims)
for (int64_t dim : reductionDims) {
slmLoadDataShape[dim] = slmShape[dim];
SmallVector<OpFoldResult> slmLoadOffsets =
getSlmOffsets(/*reductionDimStride=*/0);
slmSgData[dim] = slmShape[dim];
}
xegpu::LayoutAttr slmLoadLayout =
xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
auto slmLoadCoords =
slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
if (failed(slmLoadCoords))
return failure();
VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
auto slmLoadOp = xegpu::LoadMatrixOp::create(
rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
/*layout=*/nullptr);
SmallVector<Value> slmLoadData;
for (auto coord : *slmLoadCoords) {
SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
slmLoadData.push_back(xegpu::LoadMatrixOp::create(
rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
/*layout=*/nullptr));
}
// Step 6: Perform final reduction with neutral accumulator
// Step 5: Perform final reduction with neutral accumulator and add the
// original accumulator at the end
Value neutralFinalAcc = xegpu::createReductionNeutralValue(
rewriter, loc, sgDstType, op.getKind());
auto finalReduce = vector::MultiDimReductionOp::create(
rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
neutralFinalAcc, reductionDims);
// Step 7: Add the original accumulator at the end
auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(),
finalReduce.getResult(),
adaptor.getAcc()[0]);
rewriter.replaceOp(op, finalResult);
SmallVector<Value> finalResults;
for (size_t i = 0; i < slmLoadData.size(); ++i) {
auto loaded = slmLoadData[i];
auto finalReduce = vector::MultiDimReductionOp::create(
rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
reductionDims);
finalResults.push_back(vector::makeArithReduction(
rewriter, loc, op.getKind(), finalReduce.getResult(),
adaptor.getAcc()[i]));
}
rewriter.replaceOpWithMultiple(op, {finalResults});
return success();
}
};

View File

@ -165,6 +165,58 @@ gpu.module @test_distribution {
gpu.return
}
// Cross-subgroup reduction in which each subgroup owns more than one sg_data
// tile. The input is an 8x256 vector reduced along dim 1 with
// sg_layout = [1, 16] and sg_data = [4, 16]; the expected output below shows
// two 4x16 tiles per subgroup, so every stage (load, local reduction, SLM
// store, SLM load, final reduction) is matched twice.
// NOTE(review): the "rr" suffix presumably stands for round-robin -- confirm.
// CHECK-LABEL: gpu.func @reduction_cross_sg_rr
gpu.func @reduction_cross_sg_rr(%arg0: memref<2048xf32, 1>) kernel {
// Distributed constants: one offsets/accumulator/mask value per tile.
// CHECK: %[[CST_OFFSETS0:.*]] = arith.constant dense<0> : vector<4x16xindex>
// CHECK: %[[CST_OFFSETS1:.*]] = arith.constant dense<0> : vector<4x16xindex>
// CHECK: %[[CST_ACC0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[CST_ACC1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[CST_MASK0:.*]] = arith.constant dense<true> : vector<4x16xi1>
// CHECK: %[[CST_MASK1:.*]] = arith.constant dense<true> : vector<4x16xi1>
//
// One 4x16 load per tile.
// CHECK: %[[LOAD0:.*]] = xegpu.load %arg0[%[[CST_OFFSETS0]]], %[[CST_MASK0]]
// CHECK-SAME: -> vector<4x16xf32>
// CHECK: %[[LOAD1:.*]] = xegpu.load %arg0[%[[CST_OFFSETS1]]], %[[CST_MASK1]]
// CHECK-SAME: -> vector<4x16xf32>
//
// Local reductions: each tile is reduced along dim 1 starting from a zero
// (neutral for <add>) accumulator rather than from the original accumulator.
// CHECK: %[[NEUTRAL0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[LOCAL_RED0:.*]] = vector.multi_reduction <add>, %[[LOAD0]], %[[NEUTRAL0]] [1] : vector<4x16xf32> to vector<4xf32>
// CHECK: %[[NEUTRAL1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[LOCAL_RED1:.*]] = vector.multi_reduction <add>, %[[LOAD1]], %[[NEUTRAL1]] [1] : vector<4x16xf32> to vector<4xf32>
//
// Shape cast for SLM store: the reduced dim is restored with extent 1.
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[LOCAL_RED0]] : vector<4xf32> to vector<4x1xf32>
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[LOCAL_RED1]] : vector<4xf32> to vector<4x1xf32>
//
// SLM allocation and mem_desc: 8x16 f32 elements = 512 bytes.
// CHECK: %[[SLM:.*]] = memref.alloca() : memref<512xi8, 3>
// CHECK: %[[MEMDESC:.*]] = xegpu.create_mem_desc %[[SLM]] : memref<512xi8, 3> -> !xegpu.mem_desc<8x16xf32>
//
// Store to SLM: one 4x1 partial result per tile; the barrier must follow both
// stores. Store coordinates are intentionally not matched.
// CHECK: xegpu.store_matrix %[[SC0]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
// CHECK: xegpu.store_matrix %[[SC1]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
// CHECK: gpu.barrier
//
// Load from SLM: each tile reads back a 4x16 block of partial results.
// Load coordinates are intentionally not matched.
// CHECK: %[[SLM_LOAD0:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
// CHECK: %[[SLM_LOAD1:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
//
// Final reduction: both tiles share one neutral accumulator, and the original
// accumulator of each tile is added only after the reduction.
// CHECK: %[[FINAL_NEUTRAL:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
// CHECK: %[[FINAL_RED0:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD0]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
// CHECK: %[[RES0:.*]] = arith.addf %[[FINAL_RED0]], %[[CST_ACC0]] : vector<4xf32>
// CHECK: %[[FINAL_RED1:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD1]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
// CHECK: %[[RES1:.*]] = arith.addf %[[FINAL_RED1]], %[[CST_ACC1]] : vector<4xf32>
// Input IR: workgroup-level operands carry the [1, 16] x [4, 16] layout, and
// the accumulator and reduction result use the same layout sliced on dim 1.
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<0> : vector<8x256xindex>
%acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} dense<0.000000e+00> : vector<8xf32>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<true> : vector<8x256xi1>
%val = xegpu.load %arg0[%offset], %mask <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
%reduce = vector.multi_reduction <add>, %val, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
gpu.return
}
// CHECK-LABEL: splat_constant
gpu.func @splat_constant() {
// CHECK-COUNT-2: %[[CST:.*]] = arith.constant dense<0> : vector<4xindex>

View File

@ -1,13 +1,4 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 4)>
// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 4)>
// CHECK-DAG: #map2 = affine_map<()[s0] -> (s0 floordiv 32)>
// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 mod 32)>
// CHECK-DAG: #map4 = affine_map<()[s0] -> (0)>
// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@ -681,18 +672,9 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[AFF0:.*]] = affine.apply #map2()[%[[SGID]]]
// CHECK-DAG: %[[AFF1:.*]] = affine.apply #map3()[%[[SGID]]]
// CHECK-DAG: %[[AFF2:.*]] = affine.apply #map4()[%[[SGID]]]
// CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index
// CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index
// CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]][%[[ROW]], %[[COL0]], %[[COL1]]] : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
// CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[ROW_L:.*]] = arith.muli %[[AFF0]], %[[C1C:.*]] : index
// CHECK-DAG: %[[COL0_L:.*]] = arith.muli %[[AFF1]], %[[C0:.*]] : index
// CHECK-DAG: %[[COL1_L:.*]] = arith.muli %[[AFF2]], %[[C32B:.*]] : index
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ROW_L]], %[[COL0_L]], %[[COL1_L]]] : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [1] : vector<1x32x32xf32> to vector<1x32xf32>
// CHECK-DAG: %[[ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x32xf32>
@ -725,15 +707,9 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
// CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
// CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[AFFINE2]], %[[C32_1:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[ZERO_ROW:.*]] = arith.muli %[[AFFINE1]], %[[C0:.*]] : index
// CHECK-DAG: %[[COL_OFFSET2:.*]] = arith.muli %[[AFFINE2]], %[[C32_2:.*]] : index
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ZERO_ROW]], %[[COL_OFFSET2]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
// CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<32xf32>
@ -761,31 +737,9 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<2x2x4x4xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[I0:.*]] = arith.muli %[[AFFINE0]], %[[C1]] : index
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[I1:.*]] = arith.muli %[[AFFINE2]], %[[C1_0]] : index
// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[I2:.*]] = arith.muli %[[AFFINE4]], %[[C1_1]] : index
// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[I3:.*]] = arith.muli %[[AFFINE5]], %[[C1_2]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[I0]], %[[I1]], %[[I2]], %[[I3]]] : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C1_3]] : index
// CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C1_4]] : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0]] : index
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_0]] : index
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<1x1x4x4xf32> to vector<1x1xf32>
// CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x1xf32>
@ -811,23 +765,9 @@ gpu.module @test_distribution {
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<65536xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<32x32x4x4xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
// CHECK-DAG: %[[R0:.*]] = arith.muli %[[AFFINE0]], %[[C16_0:.*]] : index
// CHECK-DAG: %[[R1:.*]] = arith.muli %[[AFFINE2]], %[[C16_1:.*]] : index
// CHECK-DAG: %[[R2:.*]] = arith.muli %[[AFFINE4]], %[[C1_0:.*]] : index
// CHECK-DAG: %[[R3:.*]] = arith.muli %[[AFFINE5]], %[[C1_1:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[R0]], %[[R1]], %[[R2]], %[[R3]]] : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C16_2:.*]] : index
// CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C16_3:.*]] : index
// CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0_0:.*]] : index
// CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_1:.*]] : index
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32>
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<16x16x4x4xf32> to vector<16x16xf32>
// CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<16x16xf32>