diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 34c95e314f38..8474244c7d7c 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -422,6 +422,23 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
                            << descMemref << " != " << dstMemref;
   }
 
+  // TMA requires the innermost dimension of the tensor map to span a whole
+  // number of 16-byte units. Work in bits and in 64-bit arithmetic: shape
+  // extents are int64_t (storing the product in an `int` can truncate), and
+  // dividing by 8 before the check rounds sub-byte element types down, which
+  // would accept extents that are not really 16-byte multiples.
+  // NOTE(review): assumes the descriptor shape is static here -- confirm
+  // against the checks earlier in this function.
+  if (descMemref.getRank() == 0)
+    return op->emitError() << "the tensor map must have at least one dimension";
+  constexpr int64_t kLastDimAlignmentInBits = 16 * 8;
+  const int64_t lastDimBits =
+      descMemref.getShape().back() *
+      static_cast<int64_t>(descMemref.getElementTypeBitWidth());
+  if (lastDimBits % kLastDimAlignmentInBits != 0) {
+    return op->emitError() << "the bytes in the last dimension of the tensor "
+                              "map must be a multiple of 16";
+  }
   return std::nullopt;
 }
 
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index 2b64fa4a0117..f735e3f8cc62 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -378,3 +378,16 @@ func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %ar
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
   return %d : vector<2x2xf16>
 }
+
+// -----
+
+// The last dimension holds 8 x i8 = 8 bytes, which is not a multiple of 16,
+// so the TMA load must be rejected by the verifier.
+!desc = !nvgpu.tensormap.descriptor<tensor = memref<32x8xi8,3>, swizzle=none, l2promo = none, oob = zero, interleave = none>
+!mbarrier = !nvgpu.mbarrier.group<memorySpace = #gpu.address_space<workgroup>>
+func.func @tma_last_dim_bytes(%desc: !desc, %buffer: memref<32x8xi8,3>, %mbarrier: !mbarrier) {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{the bytes in the last dimension of the tensor map must be a multiple of 16}}
+  nvgpu.tma.async.load %desc[%c0, %c0], %mbarrier[%c0] to %buffer : !desc, !mbarrier -> memref<32x8xi8,3>
+  return
+}