diff --git a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
index 957b9632422a..c4d00b48a49d 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
@@ -166,6 +166,11 @@ mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   // Check if this is necessary given the assumption of 128b accesses:
   // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  // rowSize is the divisor of rowsPerLine below: bail out when the innermost
+  // dimension is dynamic or zero-sized instead of dividing by it.
+  if (ShapedType::isDynamic(rowSize) || rowSize == 0)
+    return failure();
+
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
       rowSize;
diff --git a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
index 7477e1872867..596d24b94811 100644
--- a/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
+++ b/mlir/test/Dialect/NVGPU/optimize-shared-memory.mlir
@@ -248,3 +248,14 @@ func.func @test_0_d() -> memref> {
   %alloc = memref.alloc() : memref>
   return %alloc : memref>
 }
+
+// -----
+
+// Ensure that a zero-sized or dynamic dimension does not crash the pass.
+
+// CHECK-LABEL: func @test_dynamic_and_zero_dim
+func.func @test_dynamic_and_zero_dim(%arg0 : index) {
+  %alloc = memref.alloc() : memref<0xf32, 3>
+  %alloc_1 = memref.alloc(%arg0) : memref<?xf32, 3>
+  return
+}