[MLIR][NVVM] Update TMA tensor prefetch Op (#153464)

This patch updates the TMA Tensor prefetch Op
to add support for im2col_w/w128 and tile_gather4 modes.
This completes support for all modes available in Blackwell.
* lit tests are added for all possible combinations.
* The invalid tests are moved to a separate file with more coverage.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
Durgadoss R, 2025-08-22 12:51:29 +05:30, committed by GitHub
parent 5050da7ba1
commit 36dc6146b8
5 changed files with 292 additions and 117 deletions


@@ -2302,6 +2302,56 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
 // NVVM TMA Ops
 //===----------------------------------------------------------------------===//
+
+// List of modes supported for TMA Load and Prefetch Ops
+def TMALoadModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
+def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
+def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
+
+def TMALoadMode : I32EnumAttr<"TMALoadMode", "NVVM TMA Load Mode",
+                    [TMALoadModeTile, TMALoadModeIm2Col,
+                     TMALoadModeIm2ColW, TMALoadModeIm2ColW128,
+                     TMALoadModeTileGather4]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
+  let summary = "List of Load-Modes supported for TMA Tensor Ops";
+  let description = [{
+    TMA Tensor Ops support the following modes when copying data from
+    global memory to shared memory (i.e. load):
+
+    Tile Mode: This is the default mode. The multi-dimensional layout of the
+    source tensor is preserved at the destination.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode)
+
+    Im2col Mode: This mode is used when the `im2colOffsets` operands are present.
+    The elements in the Bounding Box of the source tensor are rearranged into
+    columns at the destination. In this mode, the tensor has to be at least
+    3-dimensional. The number of `im2colOffsets` is `dims - 2`, where `dims`
+    is the dimension of the tensor.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode)
+
+    Im2col_W Mode: This mode is similar to Im2col mode, with the restriction
+    that elements are accessed across the W dimension only. The number of
+    `im2colOffsets` is always two, referred to as `wHalo` and `wOffset`.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+    Im2col_W_128 Mode: This mode is similar to Im2col_W mode, except that the
+    number of elements accessed across the W dimension is always 128.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+    Tile_Gather4 Mode: This mode is similar to Tile mode but works only on 2D
+    tensors. In gather4 mode, four rows of the source 2D tensor are combined to
+    form a single 2D tensor at the destination. This mode requires five
+    coordinates: the first one represents the column index, followed by four
+    row indices.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-scatter4-gather4-modes)
+  }];
+
+  let assemblyFormat = "`<` $value `>`";
+}
 def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
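The new modes surface in NVVM dialect assembly as a `#nvvm.tma_load_mode` attribute on the op. A minimal sketch of the syntax (the SSA value names such as `%wHalo` and `%row0` are placeholders; the forms mirror the lit tests added later in this patch):

```mlir
// Default tile mode: no attribute needed.
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr

// im2col_w always takes exactly two offsets (wHalo, wOffset).
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2]
    im2col[%wHalo, %wOffset] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr

// tile_gather4 requires exactly five coordinates: one column index, then four row indices.
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%col, %row0, %row1, %row2, %row3]
    {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
```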
@@ -2570,23 +2620,16 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
 def NVVM_CpAsyncBulkTensorPrefetchOp :
   NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
   let arguments = (ins
-    LLVM_AnyPointer:$tmaDescriptor,
+    LLVM_PointerGeneric:$tmaDescriptor,
     Variadic<I32>:$coordinates,
     Variadic<I16>:$im2colOffsets,
+    DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
     Optional<I64>:$l2CacheHint);

   let description = [{
     Initiates an asynchronous prefetch operation on the tensor data from global
-    memory to L2 cache.
-
-    The Op has two modes:
-    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
-       layout is preserved at the destination.
-
-    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
-       the elements in the Bounding Box of the source tensor are rearranged into
-       columns at the destination. In this mode, the tensor has to be at least
-       3-dimensional.
+    memory to L2 cache. This Op supports all the load modes specified in
+    `TMALoadMode`.

     The `l2CacheHint` operand is optional, and it is used to specify cache
     eviction policy that may be used during the memory access.
@@ -2603,34 +2646,17 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
   }];

   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
   }];

   let hasVerifier = 1;

   string llvmBuilder = [{
-    // Arguments to the intrinsic:
-    // tmaDesc, tensorDims, im2colOffsets
-    // cache_hint(if applicable) and flag(boolean)
-    llvm::SmallVector<llvm::Value *> translatedOperands;
-    translatedOperands.push_back($tmaDescriptor);
-
-    for (auto v : op.getCoordinates())
-      translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
-    for (auto v : op.getIm2colOffsets())
-      translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
-    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
-    auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-
-    bool isCacheHint = op.getL2CacheHint() ? true : false;
-    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
-    translatedOperands.push_back(builder.getInt1(isCacheHint));
-
-    auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
-        op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
-    createIntrinsicCall(builder, intId, translatedOperands);
+    auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+        *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
   }];
 }
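With this change, the `llvmBuilder` body simply forwards to `getIntrinsicIDAndArgs` and emits a single intrinsic call. For instance (a sketch with placeholder value names; the lowering matches the updated lit tests below):

```mlir
// With a cache hint present, this lowers to:
//   call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 %2, i1 true)
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
```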


@@ -50,7 +50,6 @@ using namespace NVVM;
 // This verifier is shared among the following Ops:
 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -98,11 +97,47 @@ LogicalResult CpAsyncOp::verify() {
   return success();
 }

+// This verifier can be shared across TMA Load and Prefetch Ops.
+static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff,
+                                         TMALoadMode mode, Location loc) {
+  if (tensorDims < 1 || tensorDims > 5)
+    return emitError(loc, "expects coordinates between 1 to 5 dimension");
+
+  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+                                size_t expectedIm2colOff) -> LogicalResult {
+    if (isIm2col && (tensorDims < 3))
+      return emitError(loc)
+             << "to use " << stringifyEnum(mode)
+             << " mode, the tensor has to be at least 3-dimensional";
+
+    if (numIm2colOff != expectedIm2colOff)
+      return emitError(loc) << " im2col offsets expected " << expectedIm2colOff
+                            << " (provided " << numIm2colOff << ")";
+
+    return success();
+  };
+
+  switch (mode) {
+  case TMALoadMode::TILE:
+    return checkTMALoadParams(mode, false, 0);
+  case TMALoadMode::IM2COL:
+    return checkTMALoadParams(mode, true, tensorDims - 2);
+  case TMALoadMode::IM2COL_W:
+  case TMALoadMode::IM2COL_W_128:
+    return checkTMALoadParams(mode, true, 2);
+  case TMALoadMode::TILE_GATHER4:
+    return (tensorDims == 5)
+               ? checkTMALoadParams(mode, false, 0)
+               : emitError(loc, "Gather4 mode expects 5 coordinates");
+  default:
+    return emitError(loc, "Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
+  }
+  return success();
+}
+
 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
 }

 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
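The behavioral change in the verifier is that the expected `im2colOffsets` count is now derived from the mode attribute instead of being inferred from the offsets themselves. A minimal sketch with placeholder values, matching the diagnostics above:

```mlir
// Valid: im2col on a 5-D tensor requires dims - 2 = 3 offsets.
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4]
    im2col[%off0, %off1, %off2] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr

// Rejected with "im2col offsets expected 2 (provided 1)": im2col_w always takes two.
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2]
    im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
```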
@@ -1435,28 +1470,57 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }

-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
-                                                                bool isIm2Col) {
-  switch (tensorDims) {
-  case 1:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
-  case 2:
-    return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
-  case 3:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
-  case 4:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
-  case 5:
-    return isIm2Col
-               ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
-               : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
-  default:
-    llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
-  }
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (auto v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  const unsigned NI = llvm::Intrinsic::not_intrinsic;
+  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d},
+      {NI, NI, NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}};
+  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+                "TMALoadModes must match number of rows in IDTable");
+
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  llvm::Intrinsic::ID id = IDTable[mode][dim];
+  if (id == llvm::Intrinsic::not_intrinsic)
+    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+  return {id, std::move(args)};
 }

 #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
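The dispatch is a two-dimensional lookup, `IDTable[mode][dims]`: each load mode is a row, the coordinate count selects the column, and `not_intrinsic` marks unsupported combinations. For example (a sketch with placeholder value names; the mapping follows the lit tests below):

```mlir
// mode = im2col_w_128 (row 3) with 4 coordinates (column 4) selects
// @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d.
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3]
    im2col[%wHalo, %wOffset] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
```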


@@ -1,70 +1,123 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

-// CHECK-LABEL: @tma_bulk_prefetch
 llvm.func @tma_bulk_prefetch(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_bulk_prefetch(ptr addrspace(1) %0, i32 %1, i64 %2) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 %2, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
   nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
   llvm.return
 }

-// CHECK-LABEL: @tma_prefetch_1d
 llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_prefetch_1d(ptr %0, i32 %1, i64 %2) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 %2, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
   llvm.return
 }

-// CHECK-LABEL: @tma_prefetch_2d
 llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+  // CHECK-LABEL: define void @tma_prefetch_2d(ptr %0, i32 %1, i32 %2, i64 %3) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 %3, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] {mode = #nvvm.tma_load_mode<tile>} : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
   llvm.return
 }

-// CHECK-LABEL: @tma_prefetch_3d
-llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+  // CHECK-LABEL: define void @tma_prefetch_3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 %6, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 %6, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
   llvm.return
 }

-// CHECK-LABEL: @tma_prefetch_4d
 llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_prefetch_4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 %7, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] : !llvm.ptr
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
   llvm.return
 }

-// CHECK-LABEL: @tma_prefetch_5d
 llvm.func @tma_prefetch_5d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+  // CHECK-LABEL: define void @tma_prefetch_5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %9, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i16 %8, i64 %9, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.5d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i16 %6, i16 %7, i64 %9, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] : !llvm.ptr
   nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] l2_cache_hint = %ch : !llvm.ptr
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
-  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.5d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] : !llvm.ptr
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1, %off2] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
   llvm.return
 }
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %ch : i64) {
+  // CHECK-LABEL: define void @tma_prefetch_gather4_2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6) {
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 0, i1 false)
+  // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.gather4.2d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i64 %6, i1 true)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+  llvm.return
+}


@@ -0,0 +1,56 @@
+// RUN: mlir-translate -verify-diagnostics -split-input-file -mlir-to-llvmir %s
+
+llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
+  // expected-error @below {{expects coordinates between 1 to 5 dimension}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
+  // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
+  // expected-error @below {{im2col offsets expected 3 (provided 2)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_3d_im2col_w(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16) {
+  // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_4d_im2col_w_128(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16) {
+  // expected-error @below {{im2col offsets expected 2 (provided 1)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_3d(%tma_desc : !llvm.ptr, %d0 : i32) {
+  // expected-error @below {{Gather4 mode expects 5 coordinates}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %x0 : i32, %y1 : i32, %y2 : i32, %y3 : i32, %y4 : i32, %off0 : i16, %ch : i64) {
+  // expected-error @below {{im2col offsets expected 0 (provided 1)}}
+  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%x0, %y1, %y2, %y3, %y4] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
+  llvm.return
+}


@@ -104,30 +104,6 @@ llvm.func @nvvm_fence_proxy_release() {

 // -----

-llvm.func @tma_prefetch_0d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
-  // expected-error @below {{expects coordinates between 1 to 5 dimension}}
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[] : !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_2d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %off0 : i16, %ch : i64) {
-  // expected-error @below {{to use im2col mode, the tensor has to be at least 3-dimensional}}
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
-  llvm.return
-}
-
-// -----
-
-llvm.func @tma_prefetch_5d_im2col(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %off0 : i16, %off1 : i16, %off2 : i16, %ch : i64) {
-  // expected-error @below {{im2col offsets must be 2 less than number of coordinates}}
-  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] im2col[%off0, %off1] : !llvm.ptr
-  llvm.return
-}
-
-// -----
-
 llvm.func @tma_reduce_0d(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %ch : i64) {
   // expected-error @below {{expects coordinates between 1 to 5 dimension}}
   nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[] {redKind = #nvvm.tma_redux_kind<add>}: !llvm.ptr, !llvm.ptr<3>