[mlir][tblgen] Add PredTypeTrait/PredAttrTrait support (#169153)

This patch adds support for `PredTypeTrait` and `PredAttrTrait` in type
and attribute definitions, enabling declarative predicate-based
verification similar to how `PredOpTrait` works for operations.

  ## Motivation

In 802bf02 (from 2021), `PredTypeTrait`/`PredAttrTrait` were defined in
TableGen but not implemented in the code generator. Using them causes
mlir-tblgen to crash with an assertion failure when trying to cast
`PredTrait` to `InterfaceTrait`. This patch fixes the crash and
implements the actual verification code generation.

  ## Usage

Use `$paramName` syntax in predicates to reference type/attribute
parameters:

  ```tablegen
  def MyType : MyDialect_Type<"MyType",
      [PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
    let parameters = (ins "unsigned":$value);
    let mnemonic = "my_type";
    let assemblyFormat = "`<` $value `>`";
  }
  ```

  This generates verification code in `verifyInvariantsImpl()`:
```cpp
  if (!(value > 0)) {
    emitError() << "failed to verify that value must be positive";
    return ::mlir::failure();
  }
  ```
This commit is contained in:
Tim Noack 2026-02-03 23:00:56 +01:00 committed by GitHub
parent 43faefdb12
commit 254b3b137e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 113 additions and 9 deletions

View File

@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
}
bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
return any_of(parameters, [](const AttrOrTypeParameter &p) {
return p.getConstraint() != std::nullopt;
});
return any_of(parameters,
[](const AttrOrTypeParameter &p) {
return p.getConstraint() != std::nullopt;
}) ||
any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
}
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {

View File

@ -22,3 +22,51 @@
// expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
"test.type_producer"() : () -> vector<memref<2xf32>>
// -----
// Test PredTypeTrait with single parameter - valid case.
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
"test.type_producer"() : () -> !test.type_pred_trait<5>
// -----
// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
// expected-error @below{{failed to verify that value must be positive}}
"test.type_producer"() : () -> !test.type_pred_trait<0>
// -----
// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
// -----
// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
// -----
// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
// expected-error @below{{failed to verify that value must be at least min}}
"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
// -----
// Test combined parameter constraint + PredTypeTrait - valid case.
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
// -----
// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
// -----
// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
// expected-error @below{{failed to verify that count must match number of elements}}
"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>

View File

@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let assemblyFormat = "`<` $param `>`";
}
// Test type with PredTypeTrait - single parameter predicate.
def TestTypePredTrait : Test_Type<"TestTypePredTrait",
[PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
let parameters = (ins "unsigned":$value);
let mnemonic = "type_pred_trait";
let assemblyFormat = "`<` $value `>`";
}
// Test type with PredTypeTrait - two parameter predicate.
def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
[PredTypeTrait<"value must be at least min",
CPred<"$value >= $minValue">>]> {
let parameters = (ins "unsigned":$value, "unsigned":$minValue);
let mnemonic = "type_pred_trait_multi";
let assemblyFormat = "`<` $value `,` $minValue `>`";
}
// Test type combining parameter type constraints with PredTypeTrait.
def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
[PredTypeTrait<"count must match number of elements",
CPred<"$count == $elements.size()">>]> {
let parameters = (ins
"unsigned":$count,
ArrayRefParameter<"int64_t">:$elements,
AnyTypeOf<[I16, I32]>:$elementType
);
let mnemonic = "type_pred_trait_combined";
let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
}
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
let mnemonic = "op_asm_type_interface";

View File

@ -246,12 +246,16 @@ void DefGen::createParentWithTraits() {
? strfmt("{0}::{1}", def.getStorageNamespace(),
def.getStorageClassName())
: strfmt("::mlir::{0}Storage", valueType));
SmallVector<std::string> traitNames =
llvm::map_to_vector(def.getTraits(), [](auto &trait) {
return isa<NativeTrait>(&trait)
? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
: cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
});
SmallVector<std::string> traitNames;
for (auto &trait : def.getTraits()) {
// Skip PredTrait as it doesn't generate a C++ trait class.
if (isa<PredTrait>(&trait))
continue;
traitNames.push_back(
isa<NativeTrait>(&trait)
? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
: cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
}
for (auto &traitName : traitNames)
defParent.addTemplateParam(traitName);
@ -386,6 +390,26 @@ void DefGen::emitInvariantsVerifierImpl() {
param.getName(), constraint->getSummary())
<< "\n";
}
{
// Generate verification for PredTraits.
FmtContext traitCtx;
for (auto it : llvm::enumerate(def.getParameters())) {
// Note: Skip over the first method parameter (`emitError`).
traitCtx.addSubst(it.value().getName(),
builderParams[it.index() + 1].getName());
}
for (const Trait &trait : def.getTraits()) {
if (auto *t = dyn_cast<PredTrait>(&trait)) {
verifier->body() << tgfmt(
"if (!($0)) {\n"
" emitError() << \"failed to verify that $1\";\n"
" return ::mlir::failure();\n"
"}\n",
&traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
}
}
}
verifier->body() << "return ::mlir::success();";
}