[mlir][bytecode] Fix crashes when reading bytecode with unsupported types (#186354)

When using test-kind=2 in the bytecode roundtrip test, integer types
(i32) are replaced by a custom type (TestI32Type) via a type callback.
This exposed two crash scenarios:

1. Reading IntegerAttr with an unsupported type: `getIntegerBitWidth`
returns 0 for unsupported types and emits an error, but
`readAPIntWithKnownWidth` would proceed to call
`reader.readAPIntWithKnownWidth(0)`, creating a zero-width APInt with a
potentially non-zero value. Fix: early-return failure when `bitWidth ==
0`.

2. Reading VectorType with an unsupported element type:
`VectorType::get` asserts that the element type implements
VectorElementTypeInterface. When the element type is replaced by a
custom type that doesn't implement this interface, the program crashes.
Fix: use `VectorType::getChecked` with a diagnostic emitter lambda
instead of `get<VectorType>` in the bytecode builder.

Fixes #128312
This commit is contained in:
Mehdi Amini 2026-03-17 11:38:12 +01:00 committed by GitHub
parent 35118457ab
commit e904d559c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 74 additions and 25 deletions

View File

@ -296,6 +296,9 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
// Use getChecked to produce a null type (and emit a diagnostic) instead of
// asserting when the element type does not implement VectorElementTypeInterface.
let cBuilder = "VectorType::getChecked([&]() { return reader.emitError(\"invalid vector type\"); }, shape, elementType)";
}
def VectorTypeWithScalableDims : DialectType<(type
@ -305,7 +308,9 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
// Use getChecked to produce a null type (and emit a diagnostic) instead of
// asserting when the element type does not implement VectorElementTypeInterface.
let cBuilder = "VectorType::getChecked([&]() { return reader.emitError(\"invalid vector type\"); }, shape, elementType, scalableDims)";
}
}

View File

@ -33,23 +33,27 @@ namespace {
// TODO: Move these to separate file.
// Returns the bitwidth if known, else return 0.
static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) {
if (auto intType = dyn_cast<IntegerType>(type)) {
// Returns the bitwidth if known, else return std::nullopt.
static std::optional<unsigned> getIntegerBitWidth(DialectBytecodeReader &reader,
Type type) {
if (auto intType = dyn_cast<IntegerType>(type))
return intType.getWidth();
}
if (llvm::isa<IndexType>(type)) {
if (llvm::isa<IndexType>(type))
return IndexType::kInternalStorageBitWidth;
}
reader.emitError()
<< "expected integer or index type for IntegerAttr, but got: " << type;
return 0;
return std::nullopt;
}
static LogicalResult readAPIntWithKnownWidth(DialectBytecodeReader &reader,
Type type, FailureOr<APInt> &val) {
unsigned bitWidth = getIntegerBitWidth(reader, type);
val = reader.readAPIntWithKnownWidth(bitWidth);
std::optional<unsigned> bitWidth = getIntegerBitWidth(reader, type);
// getIntegerBitWidth returns std::nullopt and emits an error for unsupported
// types. Bail out early to avoid creating a zero-width APInt with a non-zero
// value.
if (!bitWidth)
return failure();
val = reader.readAPIntWithKnownWidth(*bitWidth);
return val;
}

View File

@ -1,15 +0,0 @@
// RUN: not mlir-opt %s --test-bytecode-roundtrip="test-kind=2" 2>&1 | FileCheck %s
// Regression test: test-kind=2 replaces i32 with !test.i32 (a type that does
// not implement DenseElementTypeInterface). This should produce a proper error
// instead of an assertion failure when deserializing DenseTypedElementsAttr.
// CHECK: DenseTypedElementsAttr element type must implement DenseElementTypeInterface, but got: '!test.i32'
// CHECK: failed to read bytecode
module {
func.func @test() -> tensor<10xi32> {
%0 = arith.constant dense<42> : tensor<10xi32>
return %0 : tensor<10xi32>
}
}

View File

@ -0,0 +1,55 @@
// RUN: not mlir-opt %s -split-input-file --test-bytecode-roundtrip="test-kind=2" 2>&1 | FileCheck %s
// Tests that proper errors are emitted (rather than crashes) when the type
// callback replaces types with ones that are incompatible with built-in types
// and attributes (test-kind=2 replaces i32 with !test.i32).
// CHECK: expected integer or index type for IntegerAttr, but got: '!test.i32'
// CHECK: failed to read bytecode
// IntegerAttr whose type is replaced by one that is neither IntegerType nor
// IndexType previously crashed with an APInt assertion.
module {
func.func @integer_attr_unsupported_type() {
%c = arith.constant 1 : i32
return
}
}
// -----
// CHECK: failed to verify 'elementType': VectorElementTypeInterface instance
// CHECK: failed to read bytecode
// Fixed-size VectorType whose element type is replaced by one that does not
// implement VectorElementTypeInterface previously crashed in VectorType::get.
module {
func.func @vector_unsupported_elem_type() {
%cst = arith.constant dense<42> : vector<3xi32>
return
}
}
// -----
// CHECK: failed to verify 'elementType': VectorElementTypeInterface instance
// CHECK: failed to read bytecode
// Scalable VectorType whose element type is replaced by one that does not
// implement VectorElementTypeInterface exercises the VectorTypeWithScalableDims
// bytecode path.
module {
func.func @scalable_vector_unsupported_elem_type(%v : vector<[3]xi32>) {
return
}
}
// -----
// CHECK: DenseTypedElementsAttr element type must implement DenseElementTypeInterface, but got: '!test.i32'
// CHECK: failed to read bytecode
// DenseTypedElementsAttr whose element type is replaced by one that does not
// implement DenseElementTypeInterface previously crashed with an assertion.
module {
func.func @dense_elem_unsupported_type() -> tensor<10xi32> {
%0 = arith.constant dense<42> : tensor<10xi32>
return %0 : tensor<10xi32>
}
}