[mlir][tblgen] Add PredTypeTrait/PredAttrTrait support (#169153)
This patch adds support for `PredTypeTrait` and `PredAttrTrait` in type
and attribute definitions, enabling declarative predicate-based
verification similar to how `PredOpTrait` works for operations.
## Motivation
In 802bf02 (from 2021), `PredTypeTrait`/`PredAttrTrait` were defined in
TableGen but not implemented in the code generator. Using them causes
mlir-tblgen to crash with an assertion failure when trying to cast
`PredTrait` to `InterfaceTrait`. This patch fixes the crash and
implements the actual verification code generation.
## Usage
Use `$paramName` syntax in predicates to reference type/attribute
parameters:
```tablegen
def MyType : MyDialect_Type<"MyType",
[PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
let parameters = (ins "unsigned":$value);
let mnemonic = "my_type";
let assemblyFormat = "`<` $value `>`";
}
```
This generates verification code in `verifyInvariantsImpl()`:
```cpp
if (!(value > 0)) {
emitError() << "failed to verify that value must be positive";
return ::mlir::failure();
}
```
This commit is contained in:
parent
43faefdb12
commit
254b3b137e
@ -190,9 +190,11 @@ bool AttrOrTypeDef::genVerifyDecl() const {
|
||||
}
|
||||
|
||||
bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
|
||||
return any_of(parameters, [](const AttrOrTypeParameter &p) {
|
||||
return p.getConstraint() != std::nullopt;
|
||||
});
|
||||
return any_of(parameters,
|
||||
[](const AttrOrTypeParameter &p) {
|
||||
return p.getConstraint() != std::nullopt;
|
||||
}) ||
|
||||
any_of(traits, [](const Trait &t) { return isa<PredTrait>(&t); });
|
||||
}
|
||||
|
||||
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
|
||||
|
||||
@ -22,3 +22,51 @@
|
||||
|
||||
// expected-error @below{{failed to verify 'elementType': VectorElementTypeInterface instance}}
|
||||
"test.type_producer"() : () -> vector<memref<2xf32>>
|
||||
|
||||
// -----
|
||||
|
||||
// Test PredTypeTrait with single parameter - valid case.
|
||||
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait<5>
|
||||
"test.type_producer"() : () -> !test.type_pred_trait<5>
|
||||
|
||||
// -----
|
||||
|
||||
// Test PredTypeTrait with single parameter - invalid case (zero is not positive).
|
||||
// expected-error @below{{failed to verify that value must be positive}}
|
||||
"test.type_producer"() : () -> !test.type_pred_trait<0>
|
||||
|
||||
// -----
|
||||
|
||||
// Test PredTypeTrait with multiple parameters - valid case (5 >= 3).
|
||||
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
|
||||
"test.type_producer"() : () -> !test.type_pred_trait_multi<5, 3>
|
||||
|
||||
// -----
|
||||
|
||||
// Test PredTypeTrait with multiple parameters - edge case (3 >= 3).
|
||||
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
|
||||
"test.type_producer"() : () -> !test.type_pred_trait_multi<3, 3>
|
||||
|
||||
// -----
|
||||
|
||||
// Test PredTypeTrait with multiple parameters - invalid case (2 < 5).
|
||||
// expected-error @below{{failed to verify that value must be at least min}}
|
||||
"test.type_producer"() : () -> !test.type_pred_trait_multi<2, 5>
|
||||
|
||||
// -----
|
||||
|
||||
// Test combined parameter constraint + PredTypeTrait - valid case.
|
||||
// CHECK: "test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
|
||||
"test.type_producer"() : () -> !test.type_pred_trait_combined<3, [1, 2, 3], i32>
|
||||
|
||||
// -----
|
||||
|
||||
// Test combined - parameter type constraint fails (f16 not in [I16, I32]).
|
||||
// expected-error @below{{failed to verify 'elementType': 16-bit signless integer or 32-bit signless integer}}
|
||||
"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2], f16>
|
||||
|
||||
// -----
|
||||
|
||||
// Test combined - PredTypeTrait fails (count 2 != elements.size() 3).
|
||||
// expected-error @below{{failed to verify that count must match number of elements}}
|
||||
"test.type_producer"() : () -> !test.type_pred_trait_combined<2, [1, 2, 3], i16>
|
||||
|
||||
@ -406,6 +406,36 @@ def TestTypeVerification : Test_Type<"TestTypeVerification"> {
|
||||
let assemblyFormat = "`<` $param `>`";
|
||||
}
|
||||
|
||||
// Test type with PredTypeTrait - single parameter predicate.
|
||||
def TestTypePredTrait : Test_Type<"TestTypePredTrait",
|
||||
[PredTypeTrait<"value must be positive", CPred<"$value > 0">>]> {
|
||||
let parameters = (ins "unsigned":$value);
|
||||
let mnemonic = "type_pred_trait";
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
// Test type with PredTypeTrait - two parameter predicate.
|
||||
def TestTypePredTraitMultiParam : Test_Type<"TestTypePredTraitMultiParam",
|
||||
[PredTypeTrait<"value must be at least min",
|
||||
CPred<"$value >= $minValue">>]> {
|
||||
let parameters = (ins "unsigned":$value, "unsigned":$minValue);
|
||||
let mnemonic = "type_pred_trait_multi";
|
||||
let assemblyFormat = "`<` $value `,` $minValue `>`";
|
||||
}
|
||||
|
||||
// Test type combining parameter type constraints with PredTypeTrait.
|
||||
def TestTypePredTraitCombined : Test_Type<"TestTypePredTraitCombined",
|
||||
[PredTypeTrait<"count must match number of elements",
|
||||
CPred<"$count == $elements.size()">>]> {
|
||||
let parameters = (ins
|
||||
"unsigned":$count,
|
||||
ArrayRefParameter<"int64_t">:$elements,
|
||||
AnyTypeOf<[I16, I32]>:$elementType
|
||||
);
|
||||
let mnemonic = "type_pred_trait_combined";
|
||||
let assemblyFormat = "`<` $count `,` `[` $elements `]` `,` $elementType `>`";
|
||||
}
|
||||
|
||||
def TestTypeOpAsmTypeInterface : Test_Type<"TestTypeOpAsmTypeInterface",
|
||||
[DeclareTypeInterfaceMethods<OpAsmTypeInterface, ["getAsmName", "getAlias"]>]> {
|
||||
let mnemonic = "op_asm_type_interface";
|
||||
|
||||
@ -246,12 +246,16 @@ void DefGen::createParentWithTraits() {
|
||||
? strfmt("{0}::{1}", def.getStorageNamespace(),
|
||||
def.getStorageClassName())
|
||||
: strfmt("::mlir::{0}Storage", valueType));
|
||||
SmallVector<std::string> traitNames =
|
||||
llvm::map_to_vector(def.getTraits(), [](auto &trait) {
|
||||
return isa<NativeTrait>(&trait)
|
||||
? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
|
||||
: cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName();
|
||||
});
|
||||
SmallVector<std::string> traitNames;
|
||||
for (auto &trait : def.getTraits()) {
|
||||
// Skip PredTrait as it doesn't generate a C++ trait class.
|
||||
if (isa<PredTrait>(&trait))
|
||||
continue;
|
||||
traitNames.push_back(
|
||||
isa<NativeTrait>(&trait)
|
||||
? cast<NativeTrait>(&trait)->getFullyQualifiedTraitName()
|
||||
: cast<InterfaceTrait>(&trait)->getFullyQualifiedTraitName());
|
||||
}
|
||||
for (auto &traitName : traitNames)
|
||||
defParent.addTemplateParam(traitName);
|
||||
|
||||
@ -386,6 +390,26 @@ void DefGen::emitInvariantsVerifierImpl() {
|
||||
param.getName(), constraint->getSummary())
|
||||
<< "\n";
|
||||
}
|
||||
{
|
||||
// Generate verification for PredTraits.
|
||||
FmtContext traitCtx;
|
||||
for (auto it : llvm::enumerate(def.getParameters())) {
|
||||
// Note: Skip over the first method parameter (`emitError`).
|
||||
traitCtx.addSubst(it.value().getName(),
|
||||
builderParams[it.index() + 1].getName());
|
||||
}
|
||||
for (const Trait &trait : def.getTraits()) {
|
||||
if (auto *t = dyn_cast<PredTrait>(&trait)) {
|
||||
verifier->body() << tgfmt(
|
||||
"if (!($0)) {\n"
|
||||
" emitError() << \"failed to verify that $1\";\n"
|
||||
" return ::mlir::failure();\n"
|
||||
"}\n",
|
||||
&traitCtx, tgfmt(t->getPredTemplate(), &traitCtx), t->getSummary());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
verifier->body() << "return ::mlir::success();";
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user