[MLIR][XeGPU] Extend Wg-to-Sg Distribution of Multi-Reduction Op for round-robin layout (#189988)
This PR enhance the multi-reduction op pattern of wg-to-sg distribution pass: 1. allows each sg have multiple distribution of sg_data tiles. 2. expand the slm buffer size. 3. construct the layout based on the partial reduced vector and use layout.computeDistributedCoords() to compute coordinates. the layout is constructed so that the store is cooperative, and load overlapps with neighbour threads. 4. perform save and load.
This commit is contained in:
parent
97d50c1490
commit
9bddf47198
@ -1255,7 +1255,6 @@ struct WgToSgMultiDimReductionOp
|
||||
bool isScalarResult = !dstVecType;
|
||||
|
||||
auto originalSrcShape = srcType.getShape();
|
||||
int srcVecRank = originalSrcShape.size();
|
||||
Type elemTy = srcType.getElementType();
|
||||
|
||||
xegpu::DistributeLayoutAttr layout =
|
||||
@ -1268,9 +1267,11 @@ struct WgToSgMultiDimReductionOp
|
||||
// Get sg_layout and sg_data from the parent layout
|
||||
SmallVector<int64_t> sgLayout;
|
||||
SmallVector<int64_t> sgData;
|
||||
xegpu::DistributeLayoutAttr parentLayout;
|
||||
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
|
||||
sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
|
||||
sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
|
||||
parentLayout = sliceAttr.getParent();
|
||||
sgLayout = parentLayout.getEffectiveSgLayoutAsInt();
|
||||
sgData = parentLayout.getEffectiveSgDataAsInt();
|
||||
} else
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Reduction should have SliceAttr layout");
|
||||
@ -1330,26 +1331,33 @@ struct WgToSgMultiDimReductionOp
|
||||
return success();
|
||||
}
|
||||
|
||||
// Step 2: cross-subgroup reduction using SLM
|
||||
// Step 2: cross-subgroup reduction using SLM - allocating slm memory
|
||||
auto slmStoreDataShape = sgSrcShape;
|
||||
for (int64_t dim : reductionDims)
|
||||
slmStoreDataShape[dim] = 1;
|
||||
VectorType slmStoreDataType = VectorType::get(slmStoreDataShape, elemTy);
|
||||
Value slmStoreData;
|
||||
if (isScalarResult) {
|
||||
// Scalar result: broadcast scalar to vector<1x...x1> for SLM store
|
||||
slmStoreData = vector::BroadcastOp::create(
|
||||
rewriter, loc, slmStoreDataType, localReductions[0]);
|
||||
} else {
|
||||
slmStoreData = vector::ShapeCastOp::create(
|
||||
rewriter, loc, slmStoreDataType, localReductions[0]);
|
||||
SmallVector<Value> slmStoreData;
|
||||
for (auto localResult : localReductions) {
|
||||
if (isScalarResult) {
|
||||
// Scalar result: broadcast scalar to vector<1x...x1> for SLM store
|
||||
slmStoreData.push_back(vector::BroadcastOp::create(
|
||||
rewriter, loc, slmStoreDataType, localResult));
|
||||
} else {
|
||||
slmStoreData.push_back(vector::ShapeCastOp::create(
|
||||
rewriter, loc, slmStoreDataType, localResult));
|
||||
}
|
||||
}
|
||||
|
||||
// for reduction dimension, SLM stores partial results from each subgroup
|
||||
SmallVector<int64_t> slmShape(originalSrcShape.begin(),
|
||||
originalSrcShape.end());
|
||||
// for reduction dimension, SLM stores partial results from each subgroup
|
||||
for (int64_t dim : reductionDims)
|
||||
SmallVector<int> slmSgData(sgData.begin(), sgData.end());
|
||||
SmallVector<int> slmSgLayout(sgLayout.begin(), sgLayout.end());
|
||||
for (int dim : reductionDims) {
|
||||
slmShape[dim] = sgLayout[dim];
|
||||
slmSgData[dim] = 1;
|
||||
}
|
||||
xegpu::LayoutAttr slmStoreLayout =
|
||||
xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
|
||||
|
||||
// Allocate SLM
|
||||
auto bitWidth = elemTy.getIntOrFloatBitWidth();
|
||||
@ -1363,82 +1371,61 @@ struct WgToSgMultiDimReductionOp
|
||||
auto memDesc =
|
||||
xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
|
||||
|
||||
// if localReductions have more than 1 result, not support
|
||||
if (localReductions.size() > 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Multiple local reductions not supported in current implementation.");
|
||||
}
|
||||
|
||||
// Step 4: Store local results to SLM
|
||||
// Step 3: Store local results to SLM
|
||||
auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
|
||||
rewriter.getIndexType(), nullptr);
|
||||
|
||||
// Convert sgLayout to Values for delinearizeIndex
|
||||
SmallVector<Value> sgLayoutValues;
|
||||
for (int64_t dim : sgLayout)
|
||||
sgLayoutValues.push_back(
|
||||
arith::ConstantIndexOp::create(rewriter, loc, dim));
|
||||
|
||||
auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
|
||||
sgLayoutValues);
|
||||
if (failed(sgIdsResult))
|
||||
auto slmStoreCoords =
|
||||
slmStoreLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
|
||||
if (failed(slmStoreCoords))
|
||||
return failure();
|
||||
SmallVector<Value> sgIds = *sgIdsResult;
|
||||
|
||||
auto getSlmOffsets = [&](int64_t reductionDimStride) {
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
offsets.reserve(srcVecRank);
|
||||
for (int i = 0; i < srcVecRank; ++i) {
|
||||
Value dimVal = sgIds[i];
|
||||
int64_t sgDataStride = (llvm::is_contained(reductionDims, i))
|
||||
? reductionDimStride
|
||||
: sgSrcShape[i];
|
||||
Value strideVal =
|
||||
arith::ConstantIndexOp::create(rewriter, loc, sgDataStride);
|
||||
Value offsetVal =
|
||||
arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
|
||||
offsets.push_back(offsetVal);
|
||||
}
|
||||
return offsets;
|
||||
};
|
||||
|
||||
SmallVector<OpFoldResult> slmStoreOffsets =
|
||||
getSlmOffsets(/*reductionDimStride=*/1);
|
||||
|
||||
xegpu::StoreMatrixOp::create(rewriter, loc, slmStoreData,
|
||||
memDesc.getResult(), slmStoreOffsets,
|
||||
/*layout=*/nullptr);
|
||||
for (auto [data, coord] : llvm::zip(slmStoreData, *slmStoreCoords)) {
|
||||
SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
|
||||
xegpu::StoreMatrixOp::create(rewriter, loc, data, memDesc.getResult(),
|
||||
coordOfr,
|
||||
/*layout=*/nullptr);
|
||||
}
|
||||
|
||||
gpu::BarrierOp::create(rewriter, loc);
|
||||
|
||||
// Step 5: Load from SLM for final reduction
|
||||
// Step 4: Load from SLM for final reduction
|
||||
SmallVector<int64_t> slmLoadDataShape(sgSrcShape.begin(), sgSrcShape.end());
|
||||
for (int64_t dim : reductionDims)
|
||||
for (int64_t dim : reductionDims) {
|
||||
slmLoadDataShape[dim] = slmShape[dim];
|
||||
|
||||
SmallVector<OpFoldResult> slmLoadOffsets =
|
||||
getSlmOffsets(/*reductionDimStride=*/0);
|
||||
slmSgData[dim] = slmShape[dim];
|
||||
}
|
||||
xegpu::LayoutAttr slmLoadLayout =
|
||||
xegpu::LayoutAttr::get(rewriter.getContext(), slmSgLayout, slmSgData);
|
||||
auto slmLoadCoords =
|
||||
slmLoadLayout.computeDistributedCoords(rewriter, loc, sgId, slmShape);
|
||||
if (failed(slmLoadCoords))
|
||||
return failure();
|
||||
|
||||
VectorType slmLoadType = VectorType::get(slmLoadDataShape, elemTy);
|
||||
auto slmLoadOp = xegpu::LoadMatrixOp::create(
|
||||
rewriter, loc, slmLoadType, memDesc.getResult(), slmLoadOffsets,
|
||||
/*layout=*/nullptr);
|
||||
SmallVector<Value> slmLoadData;
|
||||
for (auto coord : *slmLoadCoords) {
|
||||
SmallVector<OpFoldResult> coordOfr(coord.begin(), coord.end());
|
||||
slmLoadData.push_back(xegpu::LoadMatrixOp::create(
|
||||
rewriter, loc, slmLoadType, memDesc.getResult(), coordOfr,
|
||||
/*layout=*/nullptr));
|
||||
}
|
||||
|
||||
// Step 6: Perform final reduction with neutral accumulator
|
||||
// Step 5: Perform final reduction with neutral accumulator and add the
|
||||
// original accumulator at the end
|
||||
Value neutralFinalAcc = xegpu::createReductionNeutralValue(
|
||||
rewriter, loc, sgDstType, op.getKind());
|
||||
|
||||
auto finalReduce = vector::MultiDimReductionOp::create(
|
||||
rewriter, loc, sgDstType, op.getKind(), slmLoadOp.getResult(),
|
||||
neutralFinalAcc, reductionDims);
|
||||
|
||||
// Step 7: Add the original accumulator at the end
|
||||
auto finalResult = vector::makeArithReduction(rewriter, loc, op.getKind(),
|
||||
finalReduce.getResult(),
|
||||
adaptor.getAcc()[0]);
|
||||
|
||||
rewriter.replaceOp(op, finalResult);
|
||||
SmallVector<Value> finalResults;
|
||||
for (size_t i = 0; i < slmLoadData.size(); ++i) {
|
||||
auto loaded = slmLoadData[i];
|
||||
auto finalReduce = vector::MultiDimReductionOp::create(
|
||||
rewriter, loc, sgDstType, op.getKind(), loaded, neutralFinalAcc,
|
||||
reductionDims);
|
||||
finalResults.push_back(vector::makeArithReduction(
|
||||
rewriter, loc, op.getKind(), finalReduce.getResult(),
|
||||
adaptor.getAcc()[i]));
|
||||
}
|
||||
rewriter.replaceOpWithMultiple(op, {finalResults});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@ -165,6 +165,58 @@ gpu.module @test_distribution {
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: gpu.func @reduction_cross_sg_rr
|
||||
gpu.func @reduction_cross_sg_rr(%arg0: memref<2048xf32, 1>) kernel {
|
||||
// CHECK: %[[CST_OFFSETS0:.*]] = arith.constant dense<0> : vector<4x16xindex>
|
||||
// CHECK: %[[CST_OFFSETS1:.*]] = arith.constant dense<0> : vector<4x16xindex>
|
||||
// CHECK: %[[CST_ACC0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
|
||||
// CHECK: %[[CST_ACC1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
|
||||
// CHECK: %[[CST_MASK0:.*]] = arith.constant dense<true> : vector<4x16xi1>
|
||||
// CHECK: %[[CST_MASK1:.*]] = arith.constant dense<true> : vector<4x16xi1>
|
||||
//
|
||||
// CHECK: %[[LOAD0:.*]] = xegpu.load %arg0[%[[CST_OFFSETS0]]], %[[CST_MASK0]]
|
||||
// CHECK-SAME: -> vector<4x16xf32>
|
||||
// CHECK: %[[LOAD1:.*]] = xegpu.load %arg0[%[[CST_OFFSETS1]]], %[[CST_MASK1]]
|
||||
// CHECK-SAME: -> vector<4x16xf32>
|
||||
//
|
||||
// Local reductions
|
||||
// CHECK: %[[NEUTRAL0:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
|
||||
// CHECK: %[[LOCAL_RED0:.*]] = vector.multi_reduction <add>, %[[LOAD0]], %[[NEUTRAL0]] [1] : vector<4x16xf32> to vector<4xf32>
|
||||
// CHECK: %[[NEUTRAL1:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
|
||||
// CHECK: %[[LOCAL_RED1:.*]] = vector.multi_reduction <add>, %[[LOAD1]], %[[NEUTRAL1]] [1] : vector<4x16xf32> to vector<4xf32>
|
||||
//
|
||||
// Shape cast for SLM store
|
||||
// CHECK: %[[SC0:.*]] = vector.shape_cast %[[LOCAL_RED0]] : vector<4xf32> to vector<4x1xf32>
|
||||
// CHECK: %[[SC1:.*]] = vector.shape_cast %[[LOCAL_RED1]] : vector<4xf32> to vector<4x1xf32>
|
||||
//
|
||||
// SLM allocation and mem_desc
|
||||
// CHECK: %[[SLM:.*]] = memref.alloca() : memref<512xi8, 3>
|
||||
// CHECK: %[[MEMDESC:.*]] = xegpu.create_mem_desc %[[SLM]] : memref<512xi8, 3> -> !xegpu.mem_desc<8x16xf32>
|
||||
//
|
||||
// Store to SLM
|
||||
// CHECK: xegpu.store_matrix %[[SC0]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
|
||||
// CHECK: xegpu.store_matrix %[[SC1]], %[[MEMDESC]]{{.*}} : vector<4x1xf32>, !xegpu.mem_desc<8x16xf32>
|
||||
// CHECK: gpu.barrier
|
||||
//
|
||||
// Load from SLM
|
||||
// CHECK: %[[SLM_LOAD0:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
|
||||
// CHECK: %[[SLM_LOAD1:.*]] = xegpu.load_matrix %[[MEMDESC]]{{.*}} -> vector<4x16xf32>
|
||||
//
|
||||
// Final reduction
|
||||
// CHECK: %[[FINAL_NEUTRAL:.*]] = arith.constant dense<0.000000e+00> : vector<4xf32>
|
||||
// CHECK: %[[FINAL_RED0:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD0]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
|
||||
// CHECK: %[[RES0:.*]] = arith.addf %[[FINAL_RED0]], %[[CST_ACC0]] : vector<4xf32>
|
||||
// CHECK: %[[FINAL_RED1:.*]] = vector.multi_reduction <add>, %[[SLM_LOAD1]], %[[FINAL_NEUTRAL]] [1] : vector<4x16xf32> to vector<4xf32>
|
||||
// CHECK: %[[RES1:.*]] = arith.addf %[[FINAL_RED1]], %[[CST_ACC1]] : vector<4xf32>
|
||||
|
||||
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<0> : vector<8x256xindex>
|
||||
%acc = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} dense<0.000000e+00> : vector<8xf32>
|
||||
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>} dense<true> : vector<8x256xi1>
|
||||
%val = xegpu.load %arg0[%offset], %mask <{layout = #xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>}> : memref<2048xf32, 1>, vector<8x256xindex>, vector<8x256xi1> -> vector<8x256xf32>
|
||||
%reduce = vector.multi_reduction <add>, %val, %acc {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 16], sg_data = [4, 16]>, dims = [1]>} [1] : vector<8x256xf32> to vector<8xf32>
|
||||
gpu.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: splat_constant
|
||||
gpu.func @splat_constant() {
|
||||
// CHECK-COUNT-2: %[[CST:.*]] = arith.constant dense<0> : vector<4xindex>
|
||||
|
||||
@ -1,13 +1,4 @@
|
||||
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
|
||||
|
||||
// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 4)>
|
||||
// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 4)>
|
||||
// CHECK-DAG: #map2 = affine_map<()[s0] -> (s0 floordiv 32)>
|
||||
// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 mod 32)>
|
||||
// CHECK-DAG: #map4 = affine_map<()[s0] -> (0)>
|
||||
// CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 32) floordiv 16)>
|
||||
// CHECK-DAG: #map6 = affine_map<()[s0] -> (s0 mod 16)>
|
||||
// CHECK-DAG: #map7 = affine_map<()[s0] -> ((s0 mod 16) floordiv 4)>
|
||||
gpu.module @test_distribution {
|
||||
// CHECK-LABEL: create_nd_tdesc_no_offset
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
|
||||
@ -681,18 +672,9 @@ gpu.module @test_distribution {
|
||||
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
|
||||
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<1x32x32xf32>
|
||||
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
|
||||
// CHECK-DAG: %[[AFF0:.*]] = affine.apply #map2()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFF1:.*]] = affine.apply #map3()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFF2:.*]] = affine.apply #map4()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[ROW:.*]] = arith.muli %[[AFF0]], %[[C1A:.*]] : index
|
||||
// CHECK-DAG: %[[COL0:.*]] = arith.muli %[[AFF1:.*]], %[[C1B:.*]] : index
|
||||
// CHECK-DAG: %[[COL1:.*]] = arith.muli %[[AFF2]], %[[C32A:.*]] : index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]][%[[ROW]], %[[COL0]], %[[COL1]]] : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x32xf32>, !xegpu.mem_desc<1x32x32xf32>, index, index, index
|
||||
// CHECK-DAG: gpu.barrier
|
||||
// CHECK-DAG: %[[ROW_L:.*]] = arith.muli %[[AFF0]], %[[C1C:.*]] : index
|
||||
// CHECK-DAG: %[[COL0_L:.*]] = arith.muli %[[AFF1]], %[[C0:.*]] : index
|
||||
// CHECK-DAG: %[[COL1_L:.*]] = arith.muli %[[AFF2]], %[[C32B:.*]] : index
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ROW_L]], %[[COL0_L]], %[[COL1_L]]] : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<1x32x32xf32>, index, index, index -> vector<1x32x32xf32>
|
||||
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
|
||||
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [1] : vector<1x32x32xf32> to vector<1x32xf32>
|
||||
// CHECK-DAG: %[[ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x32xf32>
|
||||
@ -725,15 +707,9 @@ gpu.module @test_distribution {
|
||||
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
|
||||
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
|
||||
// CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
|
||||
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
|
||||
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID2]]]
|
||||
// CHECK-DAG: %[[ROW_OFFSET:.*]] = arith.muli %[[AFFINE1]], %[[C1:.*]] : index
|
||||
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[AFFINE2]], %[[C32_1:.*]] : index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[ROW_OFFSET]], %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
|
||||
// CHECK-DAG: gpu.barrier
|
||||
// CHECK-DAG: %[[ZERO_ROW:.*]] = arith.muli %[[AFFINE1]], %[[C0:.*]] : index
|
||||
// CHECK-DAG: %[[COL_OFFSET2:.*]] = arith.muli %[[AFFINE2]], %[[C32_2:.*]] : index
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[ZERO_ROW]], %[[COL_OFFSET2]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
|
||||
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
|
||||
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
|
||||
// CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<32xf32>
|
||||
@ -761,31 +737,9 @@ gpu.module @test_distribution {
|
||||
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi8, 3>
|
||||
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<256xi8, 3> -> !xegpu.mem_desc<2x2x4x4xf32>
|
||||
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
|
||||
// CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[I0:.*]] = arith.muli %[[AFFINE0]], %[[C1]] : index
|
||||
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[I1:.*]] = arith.muli %[[AFFINE2]], %[[C1_0]] : index
|
||||
// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[I2:.*]] = arith.muli %[[AFFINE4]], %[[C1_1]] : index
|
||||
// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[I3:.*]] = arith.muli %[[AFFINE5]], %[[C1_2]] : index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[I0]], %[[I1]], %[[I2]], %[[I3]]] : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<1x1x1x1xf32>, !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index
|
||||
// CHECK-DAG: gpu.barrier
|
||||
// CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C1_3]] : index
|
||||
// CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C1_4]] : index
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0]] : index
|
||||
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_0]] : index
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<2x2x4x4xf32>, index, index, index, index -> vector<1x1x4x4xf32>
|
||||
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<1x1xf32>
|
||||
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<1x1x4x4xf32> to vector<1x1xf32>
|
||||
// CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<1x1xf32>
|
||||
@ -811,23 +765,9 @@ gpu.module @test_distribution {
|
||||
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<65536xi8, 3>
|
||||
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<65536xi8, 3> -> !xegpu.mem_desc<32x32x4x4xf32>
|
||||
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
|
||||
// CHECK-DAG: %[[AFFINE0:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE3:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE4:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[AFFINE5:.*]] = affine.apply {{#map[[:alnum:]_]*}}()[%[[SGID]]]
|
||||
// CHECK-DAG: %[[R0:.*]] = arith.muli %[[AFFINE0]], %[[C16_0:.*]] : index
|
||||
// CHECK-DAG: %[[R1:.*]] = arith.muli %[[AFFINE2]], %[[C16_1:.*]] : index
|
||||
// CHECK-DAG: %[[R2:.*]] = arith.muli %[[AFFINE4]], %[[C1_0:.*]] : index
|
||||
// CHECK-DAG: %[[R3:.*]] = arith.muli %[[AFFINE5]], %[[C1_1:.*]] : index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][%[[R0]], %[[R1]], %[[R2]], %[[R3]]] : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
|
||||
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]]{{.*}} : vector<16x16x1x1xf32>, !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index
|
||||
// CHECK-DAG: gpu.barrier
|
||||
// CHECK-DAG: %[[L0:.*]] = arith.muli %[[AFFINE0]], %[[C16_2:.*]] : index
|
||||
// CHECK-DAG: %[[L1:.*]] = arith.muli %[[AFFINE2]], %[[C16_3:.*]] : index
|
||||
// CHECK-DAG: %[[L2:.*]] = arith.muli %[[AFFINE4]], %[[C0_0:.*]] : index
|
||||
// CHECK-DAG: %[[L3:.*]] = arith.muli %[[AFFINE5]], %[[C0_1:.*]] : index
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[L0]], %[[L1]], %[[L2]], %[[L3]]] : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
|
||||
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]]{{.*}} : !xegpu.mem_desc<32x32x4x4xf32>, index, index, index, index -> vector<16x16x4x4xf32>
|
||||
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<16x16xf32>
|
||||
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [2, 3] : vector<16x16x4x4xf32> to vector<16x16xf32>
|
||||
// CHECK-DAG: %[[FINAL_ADD:.*]] = arith.addf %[[FINAL_REDUCE]], %[[CST]] : vector<16x16xf32>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user