diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index 0cc556ef5d85..299200788136 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -720,6 +720,14 @@ def Builtin_IntegerAttr : Builtin_Attr<"Integer", "integer", "const APInt &":$value), [{ if (type.isSignlessInteger(1)) return BoolAttr::get(type.getContext(), value.getBoolValue()); + // Validate that the APInt has the correct bit width for the given type. + if (auto intTy = ::llvm::dyn_cast(type)) { + assert(value.getBitWidth() == intTy.getWidth() && + "IntegerAttr::get: APInt bit width must match integer type width"); + } else if (::llvm::isa(type)) { + assert(value.getBitWidth() == IndexType::kInternalStorageBitWidth && + "IntegerAttr::get: APInt bit width must match IndexType internal storage bit width"); + } return $_get(type.getContext(), type, value); }]>, AttrBuilder<(ins "const APSInt &":$value), [{ diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 404aa8c0dcf3..900cacabd592 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -523,4 +523,46 @@ TEST(CopyCountAttr, PrintStripped) { EXPECT_EQ(str, "|#test.copy_count|[copy_count]"); } +//===----------------------------------------------------------------------===// +// IntegerAttr +//===----------------------------------------------------------------------===// + +TEST(IntegerAttrTest, CorrectBitWidths) { + MLIRContext context; + + // Correct APInt width for i32. + IntegerType i32 = IntegerType::get(&context, 32); + IntegerAttr attr32 = IntegerAttr::get(i32, APInt(32, 42)); + EXPECT_EQ(attr32.getType(), i32); + EXPECT_EQ(attr32.getValue().getBitWidth(), 32u); + EXPECT_EQ(attr32.getInt(), 42); + + // Correct APInt width for index type. + IndexType indexTy = IndexType::get(&context); + IntegerAttr attrIdx = + IntegerAttr::get(indexTy, APInt(IndexType::kInternalStorageBitWidth, 5)); + EXPECT_EQ(attrIdx.getType(), indexTy); + EXPECT_EQ(attrIdx.getValue().getBitWidth(), + (unsigned)IndexType::kInternalStorageBitWidth); +} + +#ifndef NDEBUG +TEST(IntegerAttrDeathTest, WrongBitWidthForIntegerType) { + MLIRContext context; + IntegerType i32 = IntegerType::get(&context, 32); + // APInt(8, 1) has bit width 8, but i32 requires 32. + EXPECT_DEATH(IntegerAttr::get(i32, APInt(8, 1)), + "APInt bit width must match integer type width"); +} + +TEST(IntegerAttrDeathTest, WrongBitWidthForIndexType) { + MLIRContext context; + IndexType indexTy = IndexType::get(&context); + // APInt(1, 1) has bit width 1, but index type requires 64. + EXPECT_DEATH( + IntegerAttr::get(indexTy, APInt(1, 1)), + "APInt bit width must match IndexType internal storage bit width"); +} +#endif // NDEBUG + } // namespace