[MLIR] convert OpAsmDialectInterface using ODS (#171488)

This PR converts OpAsmDialectInterface using ODS.

It also introduces a new Interface Method class `InterfaceMethodDeclaration` which will declare the function without definition.
This commit is contained in:
AidinT 2026-01-29 18:41:34 +01:00 committed by GitHub
parent f3ecf490a4
commit caae29c4b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 181 additions and 93 deletions

View File

@ -60,6 +60,10 @@ mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
add_mlir_generic_tablegen_target(MLIRTensorEncodingIncGen)
set(LLVM_TARGET_DEFINITIONS OpAsmDialectInterface.td)
mlir_tablegen(OpAsmDialectInterface.h.inc -gen-dialect-interface-decls)
add_mlir_generic_tablegen_target(MLIROpAsmDialectInterfaceIncGen)
add_mlir_doc(BuiltinAttributes BuiltinAttributes Dialects/ -gen-attrdef-doc)
add_mlir_doc(BuiltinLocationAttributes BuiltinLocationAttributes Dialects/ -gen-attrdef-doc)
add_mlir_doc(BuiltinOps BuiltinOps Dialects/ -gen-op-doc)

View File

@ -85,6 +85,11 @@ class StaticInterfaceMethod<string desc, string retTy, string methodName,
: InterfaceMethod<desc, retTy, methodName, args, methodBody,
defaultImplementation>;
// This class represents a interface method declaration.
class InterfaceMethodDeclaration<string desc, string retTy, string methodName,
dag args = (ins)>
: InterfaceMethod<desc, retTy, methodName, args>;
// Interface represents a base interface.
class Interface<string name, list<Interface> baseInterfacesArg = []> {
// A human-readable description of what this interface does.

View File

@ -0,0 +1,80 @@
#ifndef MLIR_INTERFACES_OPASMDIALECTINTERFACE
#define MLIR_INTERFACES_OPASMDIALECTINTERFACE
include "mlir/IR/Interfaces.td"
def OpAsmDialectInterface : DialectInterface<"OpAsmDialectInterface"> {
let description = [{
Dialect OpAsm interface
}];
let cppNamespace = "::mlir";
let extraClassDeclaration = [{
using AliasResult = OpAsmAliasResult;
}];
let methods = [
InterfaceMethod<[{
Hooks for getting an alias identifier alias for a given symbol, that is
not necessarily a part of this dialect. The identifier is used in place of
the symbol when printing textual IR. These aliases must not contain `.` or
end with a numeric digit ([0-9]+).
}],
"OpAsmAliasResult", "getAlias",
(ins "::mlir::Attribute":$attr, "::llvm::raw_ostream &":$os),
[{
return OpAsmAliasResult::NoAlias;
}]
>,
InterfaceMethod<[{}], "OpAsmAliasResult", "getAlias",
(ins "::mlir::Type":$type, "::llvm::raw_ostream &":$os),
[{
return OpAsmAliasResult::NoAlias;
}]
>,
InterfaceMethod<[{
Declare a resource with the given key, returning a handle to use for any
references of this resource key within the IR during parsing. The result
of `getResourceKey` on the returned handle is permitted to be different
than `key`.
}],
"::mlir::FailureOr<::mlir::AsmDialectResourceHandle>", "declareResource",
(ins "::mlir::StringRef":$key),
[{
return failure();
}]
>,
InterfaceMethod<[{
Return a key to use for the given resource. This key should uniquely
identify this resource within the dialect.
}],
"std::string", "getResourceKey",
(ins "const ::mlir::AsmDialectResourceHandle &":$handle),
[{
llvm_unreachable(
"Dialect must implement `getResourceKey` when defining resources");
}]
>,
InterfaceMethodDeclaration<[{
Hook for parsing resource entries. Returns failure if the entry was not
valid, or could otherwise not be processed correctly. Any necessary errors
can be emitted via the provided entry.
}],
"::llvm::LogicalResult", "parseResource",
(ins "::mlir::AsmParsedResourceEntry &":$entry)
>,
InterfaceMethod<[{
Hook for building resources to use during printing. The given `op` may be
inspected to help determine what information to include.
`referencedResources` contains all of the resources detected when printing
'op'.
}],
"void", "buildResources",
(ins "::mlir::Operation *":$op,
"const ::mlir::SetVector<::mlir::AsmDialectResourceHandle> &":$referencedResources,
"::mlir::AsmResourceBuilder &":$builder)
>
];
}
#endif

View File

@ -1779,64 +1779,6 @@ public:
SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
};
//===--------------------------------------------------------------------===//
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//
class OpAsmDialectInterface
: public DialectInterface::Base<OpAsmDialectInterface> {
public:
OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
using AliasResult = OpAsmAliasResult;
/// Hooks for getting an alias identifier alias for a given symbol, that is
/// not necessarily a part of this dialect. The identifier is used in place of
/// the symbol when printing textual IR. These aliases must not contain `.` or
/// end with a numeric digit([0-9]+).
virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
return AliasResult::NoAlias;
}
virtual AliasResult getAlias(Type type, raw_ostream &os) const {
return AliasResult::NoAlias;
}
//===--------------------------------------------------------------------===//
// Resources
//===--------------------------------------------------------------------===//
/// Declare a resource with the given key, returning a handle to use for any
/// references of this resource key within the IR during parsing. The result
/// of `getResourceKey` on the returned handle is permitted to be different
/// than `key`.
virtual FailureOr<AsmDialectResourceHandle>
declareResource(StringRef key) const {
return failure();
}
/// Return a key to use for the given resource. This key should uniquely
/// identify this resource within the dialect.
virtual std::string
getResourceKey(const AsmDialectResourceHandle &handle) const {
llvm_unreachable(
"Dialect must implement `getResourceKey` when defining resources");
}
/// Hook for parsing resource entries. Returns failure if the entry was not
/// valid, or could otherwise not be processed correctly. Any necessary errors
/// can be emitted via the provided entry.
virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
/// Hook for building resources to use during printing. The given `op` may be
/// inspected to help determine what information to include.
/// `referencedResources` contains all of the resources detected when printing
/// 'op'.
virtual void
buildResources(Operation *op,
const SetVector<AsmDialectResourceHandle> &referencedResources,
AsmResourceBuilder &builder) const {}
};
//===--------------------------------------------------------------------===//
// Custom printers and parsers.
//===--------------------------------------------------------------------===//
@ -1856,6 +1798,13 @@ ParseResult parseDimensionList(OpAsmParser &parser,
/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmOpInterface.h.inc"
//===--------------------------------------------------------------------===//
// Dialect OpAsm interface.
//===--------------------------------------------------------------------===//
/// The OpAsmDialectInterface, see OpAsmDialectInterface.td
#include "mlir/IR/OpAsmDialectInterface.h.inc"
namespace llvm {
template <>
struct DenseMapInfo<mlir::AsmDialectResourceHandle> {

View File

@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator.h"
@ -46,6 +47,9 @@ public:
// Return if this method is static.
bool isStatic() const;
// Return if the method is only a declaration.
bool isDeclaration() const;
// Return the body for this method if it has one.
std::optional<StringRef> getBody() const;

View File

@ -67,6 +67,7 @@ add_mlir_library(MLIRIR
MLIRSideEffectInterfacesIncGen
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
MLIROpAsmDialectInterfaceIncGen
LINK_LIBS PUBLIC
MLIRSupport

View File

@ -11,6 +11,7 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <utility>
using namespace mlir;
using namespace mlir::tblgen;
@ -51,6 +52,11 @@ bool InterfaceMethod::isStatic() const {
return def->isSubClassOf("StaticInterfaceMethod");
}
// Return if the method is only a declaration.
bool InterfaceMethod::isDeclaration() const {
return def->isSubClassOf("InterfaceMethodDeclaration");
}
// Return the body for this method if it has one.
std::optional<StringRef> InterfaceMethod::getBody() const {
// Trim leading and trailing spaces from the default implementation.

View File

@ -11,24 +11,18 @@ def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
let methods = [
InterfaceMethod<
/*desc=*/ "Check if it's an example dialect",
/*returnType=*/ "bool",
/*methodName=*/ "isExampleDialect",
/*args=*/ (ins)
>,
InterfaceMethod<
/*desc=*/ "second method to check if multiple methods supported",
/*returnType=*/ "unsigned",
/*methodName=*/ "supportSecondMethod",
/*args=*/ (ins "::mlir::Type":$type)
"Check if it's an example dialect", "bool", "isExampleDialect", (ins)
>,
InterfaceMethod<
"second method to check if multiple methods supported",
"unsigned", "supportSecondMethod", (ins "::mlir::Type":$type)
>
];
}
// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
// DECL: public:
// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
// DECL: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
// DECL: virtual bool isExampleDialect() const {}
// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const {}
@ -40,26 +34,42 @@ def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface">
let cppNamespace = "::mlir::example";
let methods = [
InterfaceMethod<
/*desc=*/ "Check if it's an example dialect",
/*returnType=*/ "bool",
/*methodName=*/ "isExampleDialect",
/*args=*/ (ins),
/*methodBody=*/ [{
InterfaceMethod<
"Check if it's an example dialect", "bool", "isExampleDialect", (ins),
[{
return true;
}]
}]
>,
InterfaceMethod<
/*desc=*/ "second method to check if multiple methods supported",
/*returnType=*/ "unsigned",
/*methodName=*/ "supportSecondMethod",
/*args=*/ (ins "::mlir::Type":$type)
"second method to check if multiple methods supported",
"unsigned", "supportSecondMethod", (ins "::mlir::Type":$type)
>
];
}
// DECL: virtual bool isExampleDialect() const {
// DECL-NEXT: return true;
// DECL-NEXT: }
// DECL: virtual bool isExampleDialect() const {
// DECL-NEXT: return true;
// DECL-NEXT: }
def OnlyDeclarationInterfaceWithExtraDecls : DialectInterface<"OnlyDeclarationInterfaceWithExtraDecls"> {
let description = [{
This is an example dialect interface with only declarations.
}];
let cppNamespace = "::mlir::example";
let methods = [
InterfaceMethodDeclaration<
"a method declaration", "void", "exampleMethodDeclaration",
(ins "::mlir::Type":$type)
>
];
let extraClassDeclaration = [{
using DeclType = int;
}];
}
// DECL: class OnlyDeclarationInterfaceWithExtraDecls : public {{.*}}DialectInterface::Base<OnlyDeclarationInterfaceWithExtraDecls>
// DECL: virtual void exampleMethodDeclaration(::mlir::Type type) const;
// DECL: using DeclType = int;

View File

@ -15,6 +15,7 @@
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
@ -26,7 +27,7 @@
using namespace mlir;
using llvm::Record;
using llvm::RecordKeeper;
using mlir::tblgen::Interface;
using mlir::tblgen::DialectInterface;
using mlir::tblgen::InterfaceMethod;
/// Emit a string corresponding to a C++ type, followed by a space if necessary.
@ -74,7 +75,7 @@ public:
bool emitInterfaceDecls();
protected:
void emitInterfaceDecl(const Interface &interface);
void emitInterfaceDecl(const DialectInterface &interface);
/// The set of interface records to emit.
std::vector<const Record *> defs;
@ -91,9 +92,11 @@ static void emitInterfaceMethodDoc(const InterfaceMethod &method,
raw_ostream &os, StringRef prefix = "") {
if (std::optional<StringRef> description = method.getDescription())
tblgen::emitDescriptionComment(*description, os, prefix);
else
os << "\n";
}
static void emitInterfaceMethodsDef(const Interface &interface,
static void emitInterfaceMethodsDef(const DialectInterface &interface,
raw_ostream &os) {
raw_indented_ostream ios(os);
@ -104,6 +107,13 @@ static void emitInterfaceMethodsDef(const Interface &interface,
ios << "virtual ";
emitCPPType(method.getReturnType(), ios);
emitMethodNameAndArgs(method, method.getName(), ios);
if (method.isDeclaration()) {
ios << ";\n";
continue;
}
// if it is not a method declaration, then it's a normal interface method.
ios << " {";
if (auto body = method.getBody()) {
@ -116,11 +126,10 @@ static void emitInterfaceMethodsDef(const Interface &interface,
}
}
void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
void DialectInterfaceGenerator::emitInterfaceDecl(
const DialectInterface &interface) {
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
StringRef interfaceName = interface.getName();
tblgen::emitSummaryAndDescComments(os, "",
interface.getDescription().value_or(""));
@ -129,10 +138,19 @@ void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
"class {0} : public ::mlir::DialectInterface::Base<{0}> {\n"
"public:\n"
" {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
interfaceName);
interface.getName());
emitInterfaceMethodsDef(interface, os);
// Emit any extra declarations.
if (std::optional<StringRef> extraDecls =
interface.getExtraClassDeclaration()) {
raw_indented_ostream ios(os);
ios.indent(2);
ios.printReindented(extraDecls.value());
ios << "\n";
}
os << "};\n";
}
@ -148,7 +166,7 @@ bool DialectInterfaceGenerator::emitInterfaceDecls() {
});
for (const Record *def : sortedDefs)
emitInterfaceDecl(Interface(def));
emitInterfaceDecl(DialectInterface(def));
return false;
}

View File

@ -113,6 +113,16 @@ gentbl_cc_library(
deps = [":OpBaseTdFiles"],
)
gentbl_cc_library(
name = "OpAsmDialectInterfaceIncGen",
tbl_outs = {
"include/mlir/IR/OpAsmDialectInterface.h.inc": ["-gen-dialect-interface-decls"],
},
tblgen = ":mlir-tblgen",
td_file = "include/mlir/IR/OpAsmDialectInterface.td",
deps = [":OpBaseTdFiles"],
)
gentbl_cc_library(
name = "TensorEncodingIncGen",
tbl_outs = {
@ -397,6 +407,7 @@ cc_library(
"lib/Bytecode/Writer/*.h",
"include/mlir/Bytecode/*.h",
]) + [
"include/mlir/IR/OpAsmDialectInterface.h.inc",
"include/mlir/IR/OpAsmOpInterface.h.inc",
"include/mlir/Interfaces/DataLayoutInterfaces.h",
"include/mlir/Interfaces/InferIntRangeInterface.h",