From e904d559c5ae071a5d04cdf003021cb3df2bb1a4 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 17 Mar 2026 11:38:12 +0100 Subject: [PATCH] [mlir][bytecode] Fix crashes when reading bytecode with unsupported types (#186354) When using test-kind=2 in the bytecode roundtrip test, integer types (i32) are replaced by a custom type (TestI32Type) via a type callback. This exposed two crash scenarios: 1. Reading IntegerAttr with an unsupported type: `getIntegerBitWidth` returns 0 for unsupported types and emits an error, but `readAPIntWithKnownWidth` would proceed to call `reader.readAPIntWithKnownWidth(0)`, creating a zero-width APInt with a potentially non-zero value. Fix: early-return failure when `bitWidth == 0`. 2. Reading VectorType with an unsupported element type: `VectorType::get` asserts that the element type implements VectorElementTypeInterface. When the element type is replaced by a custom type that doesn't implement this interface, the program crashes. Fix: use `VectorType::getChecked` with a diagnostic emitter lambda instead of `get` in the bytecode builder. Fixes #128312 --- .../include/mlir/IR/BuiltinDialectBytecode.td | 7 ++- mlir/lib/IR/BuiltinDialectBytecode.cpp | 22 +++++--- .../invalid-dense-elem-type-interface.mlir | 15 ----- .../invalid/invalid-type-remapping.mlir | 55 +++++++++++++++++++ 4 files changed, 74 insertions(+), 25 deletions(-) delete mode 100644 mlir/test/Bytecode/invalid/invalid-dense-elem-type-interface.mlir create mode 100644 mlir/test/Bytecode/invalid/invalid-type-remapping.mlir diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td index 53a859e32d64..64cc8a8ff5e2 100644 --- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td @@ -296,6 +296,9 @@ def VectorType : DialectType<(type Type:$elementType )> { let printerPredicate = "!$_val.isScalable()"; + // Use getChecked to produce a null type (and emit a diagnostic) instead of + // asserting when the element type does not implement VectorElementTypeInterface. + let cBuilder = "VectorType::getChecked([&]() { return reader.emitError(\"invalid vector type\"); }, shape, elementType)"; } def VectorTypeWithScalableDims : DialectType<(type @@ -305,7 +308,9 @@ def VectorTypeWithScalableDims : DialectType<(type )> { let printerPredicate = "$_val.isScalable()"; // Note: order of serialization does not match order of builder. - let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)"; + // Use getChecked to produce a null type (and emit a diagnostic) instead of + // asserting when the element type does not implement VectorElementTypeInterface. + let cBuilder = "VectorType::getChecked([&]() { return reader.emitError(\"invalid vector type\"); }, shape, elementType, scalableDims)"; } } diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp index f7430784dd22..14dc66518409 100644 --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -33,23 +33,27 @@ namespace { // TODO: Move these to separate file. -// Returns the bitwidth if known, else return 0. -static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) { - if (auto intType = dyn_cast(type)) { +// Returns the bitwidth if known, else return std::nullopt. +static std::optional getIntegerBitWidth(DialectBytecodeReader &reader, + Type type) { + if (auto intType = dyn_cast(type)) return intType.getWidth(); - } - if (llvm::isa(type)) { + if (llvm::isa(type)) return IndexType::kInternalStorageBitWidth; - } reader.emitError() << "expected integer or index type for IntegerAttr, but got: " << type; - return 0; + return std::nullopt; } static LogicalResult readAPIntWithKnownWidth(DialectBytecodeReader &reader, Type type, FailureOr &val) { - unsigned bitWidth = getIntegerBitWidth(reader, type); - val = reader.readAPIntWithKnownWidth(bitWidth); + std::optional bitWidth = getIntegerBitWidth(reader, type); + // getIntegerBitWidth returns std::nullopt and emits an error for unsupported + // types. Bail out early to avoid creating a zero-width APInt with a non-zero + // value. + if (!bitWidth) + return failure(); + val = reader.readAPIntWithKnownWidth(*bitWidth); return val; } diff --git a/mlir/test/Bytecode/invalid/invalid-dense-elem-type-interface.mlir b/mlir/test/Bytecode/invalid/invalid-dense-elem-type-interface.mlir deleted file mode 100644 index f076dcb9b2f1..000000000000 --- a/mlir/test/Bytecode/invalid/invalid-dense-elem-type-interface.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: not mlir-opt %s --test-bytecode-roundtrip="test-kind=2" 2>&1 | FileCheck %s - -// Regression test: test-kind=2 replaces i32 with !test.i32 (a type that does -// not implement DenseElementTypeInterface). This should produce a proper error -// instead of an assertion failure when deserializing DenseTypedElementsAttr. - -// CHECK: DenseTypedElementsAttr element type must implement DenseElementTypeInterface, but got: '!test.i32' -// CHECK: failed to read bytecode - -module { - func.func @test() -> tensor<10xi32> { - %0 = arith.constant dense<42> : tensor<10xi32> - return %0 : tensor<10xi32> - } -} diff --git a/mlir/test/Bytecode/invalid/invalid-type-remapping.mlir b/mlir/test/Bytecode/invalid/invalid-type-remapping.mlir new file mode 100644 index 000000000000..44d0a4eb8bb4 --- /dev/null +++ b/mlir/test/Bytecode/invalid/invalid-type-remapping.mlir @@ -0,0 +1,55 @@ +// RUN: not mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=2" 2>&1 | FileCheck %s + +// Tests that proper errors are emitted (rather than crashes) when the type +// callback replaces types with ones that are incompatible with built-in types +// and attributes (test-kind=2 replaces i32 with !test.i32). + +// CHECK: expected integer or index type for IntegerAttr, but got: '!test.i32' +// CHECK: failed to read bytecode +// IntegerAttr whose type is replaced by one that is neither IntegerType nor +// IndexType — previously crashed with an APInt assertion. +module { + func.func @integer_attr_unsupported_type() { + %c = arith.constant 1 : i32 + return + } +} + +// ----- + +// CHECK: failed to verify 'elementType': VectorElementTypeInterface instance +// CHECK: failed to read bytecode +// Fixed-size VectorType whose element type is replaced by one that does not +// implement VectorElementTypeInterface — previously crashed in VectorType::get. +module { + func.func @vector_unsupported_elem_type() { + %cst = arith.constant dense<42> : vector<3xi32> + return + } +} + +// ----- + +// CHECK: failed to verify 'elementType': VectorElementTypeInterface instance +// CHECK: failed to read bytecode +// Scalable VectorType whose element type is replaced by one that does not +// implement VectorElementTypeInterface — exercises the VectorTypeWithScalableDims +// bytecode path. +module { + func.func @scalable_vector_unsupported_elem_type(%v : vector<[3]xi32>) { + return + } +} + +// ----- + +// CHECK: DenseTypedElementsAttr element type must implement DenseElementTypeInterface, but got: '!test.i32' +// CHECK: failed to read bytecode +// DenseTypedElementsAttr whose element type is replaced by one that does not +// implement DenseElementTypeInterface — previously crashed with an assertion. +module { + func.func @dense_elem_unsupported_type() -> tensor<10xi32> { + %0 = arith.constant dense<42> : tensor<10xi32> + return %0 : tensor<10xi32> + } +}