[mlir][irdl] Add support for basic structural constraints in tblgen-to-irdl (#82862)

This commit is contained in:
Fehr Mathieu 2024-03-05 14:53:59 +00:00 committed by GitHub
parent 1c2b79add6
commit a64975f966
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 134 additions and 22 deletions

View File

@ -168,24 +168,28 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
BuildableType<"$_builder.getType<::mlir::NoneType>()">;
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypes, t.summary), " or "),
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
cppClassName>;
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
}
// A type that satisfies the constraints of all given types.
class AllOfType<list<Type> allowedTypes, string summary = "",
class AllOfType<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy all of the allowedf types' conditions.
And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
// Satisfy all of the allowed types' conditions.
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypes, t.summary), " and "),
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
cppClassName>;
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
}
// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",

View File

@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
}
// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
}
// CHECK: irdl.operation @mul {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: %2 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0, %1)
// CHECK-NEXT: irdl.results(%2)
// CHECK-NEXT: }
@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> {
}
// CHECK: irdl.operation @norm {
// CHECK-NEXT: %0 = irdl.c_pred "(true)"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.any
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0)
// CHECK-NEXT: irdl.results(%1)
// CHECK-NEXT: }

View File

@ -0,0 +1,74 @@
// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s
include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"
// CHECK-LABEL: irdl.dialect @test {
def Test_Dialect : Dialect {
let name = "test";
}
class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Test_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}
class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;
def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
// Check that AllOfType is converted correctly.
def Test_AndOp : Test_Op<"and"> {
let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
}
// CHECK-LABEL: irdl.operation @and {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: irdl.operands(%[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
// Check that AnyType is converted correctly.
def Test_AnyOp : Test_Op<"any"> {
let arguments = (ins AnyType:$in);
}
// CHECK-LABEL: irdl.operation @any {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
// Check that AnyTypeOf is converted correctly.
def Test_OrOp : Test_Op<"or"> {
let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
}
// CHECK-LABEL: irdl.operation @or {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.operands(%[[v3]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
// Check that variadics and optionals are converted correctly.
def Test_VariadicityOp : Test_Op<"variadicity"> {
let arguments = (ins Variadic<Test_SingletonAType>:$variadic,
Optional<Test_SingletonBType>:$optional,
Test_SingletonCType:$required);
}
// CHECK-LABEL: irdl.operation @variadicity {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }

View File

@ -39,15 +39,49 @@ llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
irdl::CPredOp createConstraint(OpBuilder &builder,
NamedTypeConstraint namedConstraint) {
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
// Build the constraint as a string.
std::string constraint =
namedConstraint.constraint.getPredicate().getCondition();
const Record &predRec = constraint.getDef();
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
return createConstraint(builder, predRec.getValueAsDef("baseType"));
if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
return op.getOutput();
}
if (predRec.isSubClassOf("TypeDef")) {
std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
StringAttr::get(ctx, typeName));
return op.getOutput();
}
if (predRec.isSubClassOf("AnyTypeOf")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}
if (predRec.isSubClassOf("AllOfType")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}
std::string condition = constraint.getPredicate().getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
return op;
}
@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
auto operand = createConstraint(consBuilder, namedCons);
auto operand = createConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);
irdl::VariadicityAttr var;