[NFC][MLIR][TableGen] Eliminate llvm:: for commonly used types (#110841)

Eliminate `llvm::` namespace qualifier for commonly used types in MLIR
TableGen backends to reduce code clutter.
This commit is contained in:
Rahul Joshi 2024-10-02 13:23:44 -07:00 committed by GitHub
parent 906ffc4b4a
commit bccd37f69f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 450 additions and 475 deletions

View File

@ -22,6 +22,8 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::Record;
using llvm::RecordKeeper;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Utility Functions // Utility Functions
@ -30,14 +32,14 @@ using namespace mlir::tblgen;
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect /// Find all the AttrOrTypeDef for the specified dialect. If no dialect
/// specified and can only find one dialect's defs, use that. /// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect, static void collectAllDefs(StringRef selectedDialect,
ArrayRef<const llvm::Record *> records, ArrayRef<const Record *> records,
SmallVectorImpl<AttrOrTypeDef> &resultDefs) { SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
// Nothing to do if no defs were found. // Nothing to do if no defs were found.
if (records.empty()) if (records.empty())
return; return;
auto defs = llvm::map_range( auto defs = llvm::map_range(
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); }); records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
if (selectedDialect.empty()) { if (selectedDialect.empty()) {
// If a dialect was not specified, ensure that all found defs belong to the // If a dialect was not specified, ensure that all found defs belong to the
// same dialect. // same dialect.
@ -690,13 +692,12 @@ public:
bool emitDefs(StringRef selectedDialect); bool emitDefs(StringRef selectedDialect);
protected: protected:
DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os, DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator) StringRef defType, StringRef valueType, bool isAttrGenerator)
: defRecords(defs), os(os), defType(defType), valueType(valueType), : defRecords(defs), os(os), defType(defType), valueType(valueType),
isAttrGenerator(isAttrGenerator) { isAttrGenerator(isAttrGenerator) {
// Sort by occurrence in file. // Sort by occurrence in file.
llvm::sort(defRecords, llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
[](const llvm::Record *lhs, const llvm::Record *rhs) {
return lhs->getID() < rhs->getID(); return lhs->getID() < rhs->getID();
}); });
} }
@ -707,7 +708,7 @@ protected:
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs); void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
/// The set of def records to emit. /// The set of def records to emit.
std::vector<const llvm::Record *> defRecords; std::vector<const Record *> defRecords;
/// The attribute or type class to emit. /// The attribute or type class to emit.
/// The stream to emit to. /// The stream to emit to.
raw_ostream &os; raw_ostream &os;
@ -722,13 +723,13 @@ protected:
/// A specialized generator for AttrDefs. /// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator { struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os, : DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
"Attr", "Attribute", /*isAttrGenerator=*/true) {} "Attr", "Attribute", /*isAttrGenerator=*/true) {}
}; };
/// A specialized generator for TypeDefs. /// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator { struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os, : DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
"Type", "Type", /*isAttrGenerator=*/false) {} "Type", "Type", /*isAttrGenerator=*/false) {}
}; };
@ -1030,9 +1031,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
/// Find all type constraints for which a C++ function should be generated. /// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint> static std::vector<Constraint>
getAllTypeConstraints(const llvm::RecordKeeper &records) { getAllTypeConstraints(const RecordKeeper &records) {
std::vector<Constraint> result; std::vector<Constraint> result;
for (const llvm::Record *def : for (const Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) { records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
// Ignore constraints defined outside of the top-level file. // Ignore constraints defined outside of the top-level file.
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
@ -1047,7 +1048,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
return result; return result;
} }
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records, static void emitTypeConstraintDecls(const RecordKeeper &records,
raw_ostream &os) { raw_ostream &os) {
static const char *const typeConstraintDecl = R"( static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type); bool {0}(::mlir::Type type);
@ -1057,7 +1058,7 @@ bool {0}(::mlir::Type type);
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName()); os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
} }
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records, static void emitTypeConstraintDefs(const RecordKeeper &records,
raw_ostream &os) { raw_ostream &os) {
static const char *const typeConstraintDef = R"( static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) { bool {0}(::mlir::Type type) {
@ -1088,13 +1089,13 @@ static llvm::cl::opt<std::string>
static mlir::GenRegistration static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions", genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os); AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect); return generator.emitDefs(attrDialect);
}); });
static mlir::GenRegistration static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations", genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os); AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect); return generator.emitDecls(attrDialect);
}); });
@ -1110,13 +1111,13 @@ static llvm::cl::opt<std::string>
static mlir::GenRegistration static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions", genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os); TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect); return generator.emitDefs(typeDialect);
}); });
static mlir::GenRegistration static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations", genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os); TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect); return generator.emitDecls(typeDialect);
}); });
@ -1124,14 +1125,14 @@ static mlir::GenRegistration
static mlir::GenRegistration static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs", genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions", "Generate type constraint definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os); emitTypeConstraintDefs(records, os);
return false; return false;
}); });
static mlir::GenRegistration static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls", genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations", "Generate type constraint declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os); emitTypeConstraintDecls(records, os);
return false; return false;
}); });

View File

@ -18,11 +18,10 @@
using namespace llvm; using namespace llvm;
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
static llvm::cl::opt<std::string> static cl::opt<std::string>
selectedBcDialect("bytecode-dialect", selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
llvm::cl::desc("The dialect to gen for"), cl::cat(dialectGenCat), cl::CommaSeparated);
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
namespace { namespace {
@ -306,7 +305,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
auto funScope = os.scope("{\n", "}\n\n"); auto funScope = os.scope("{\n", "}\n\n");
// Check that predicates specified if multiple bytecode instances. // Check that predicates specified if multiple bytecode instances.
for (const llvm::Record *rec : make_second_range(vec)) { for (const Record *rec : make_second_range(vec)) {
StringRef pred = rec->getValueAsString("printerPredicate"); StringRef pred = rec->getValueAsString("printerPredicate");
if (vec.size() > 1 && pred.empty()) { if (vec.size() > 1 && pred.empty()) {
for (auto [index, rec] : vec) { for (auto [index, rec] : vec) {

View File

@ -30,6 +30,8 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::Record;
using llvm::RecordKeeper;
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*"); static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
llvm::cl::opt<std::string> llvm::cl::opt<std::string>
@ -39,8 +41,8 @@ llvm::cl::opt<std::string>
/// Utility iterator used for filtering records for a specific dialect. /// Utility iterator used for filtering records for a specific dialect.
namespace { namespace {
using DialectFilterIterator = using DialectFilterIterator =
llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator, llvm::filter_iterator<ArrayRef<Record *>::iterator,
std::function<bool(const llvm::Record *)>>; std::function<bool(const Record *)>>;
} // namespace } // namespace
static void populateDiscardableAttributes( static void populateDiscardableAttributes(
@ -62,8 +64,8 @@ static void populateDiscardableAttributes(
/// the given dialect. /// the given dialect.
template <typename T> template <typename T>
static iterator_range<DialectFilterIterator> static iterator_range<DialectFilterIterator>
filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) { filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
auto filterFn = [&](const llvm::Record *record) { auto filterFn = [&](const Record *record) {
return T(record).getDialect() == dialect; return T(record).getDialect() == dialect;
}; };
return {DialectFilterIterator(records.begin(), records.end(), filterFn), return {DialectFilterIterator(records.begin(), records.end(), filterFn),
@ -295,7 +297,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
<< "::" << dialect.getCppClassName() << ")\n"; << "::" << dialect.getCppClassName() << ")\n";
} }
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper, static bool emitDialectDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) { raw_ostream &os) {
emitSourceFileHeader("Dialect Declarations", os, recordKeeper); emitSourceFileHeader("Dialect Declarations", os, recordKeeper);
@ -340,8 +342,7 @@ static const char *const dialectDestructorStr = R"(
)"; )";
static void emitDialectDef(Dialect &dialect, static void emitDialectDef(Dialect &dialect, const RecordKeeper &recordKeeper,
const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) { raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName(); std::string cppClassName = dialect.getCppClassName();
@ -389,8 +390,7 @@ static void emitDialectDef(Dialect &dialect,
os << llvm::formatv(dialectDestructorStr, cppClassName); os << llvm::formatv(dialectDestructorStr, cppClassName);
} }
static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper, static bool emitDialectDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
raw_ostream &os) {
emitSourceFileHeader("Dialect Definitions", os, recordKeeper); emitSourceFileHeader("Dialect Definitions", os, recordKeeper);
auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect"); auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
@ -411,12 +411,12 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
static mlir::GenRegistration static mlir::GenRegistration
genDialectDecls("gen-dialect-decls", "Generate dialect declarations", genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
return emitDialectDecls(records, os); return emitDialectDecls(records, os);
}); });
static mlir::GenRegistration static mlir::GenRegistration
genDialectDefs("gen-dialect-defs", "Generate dialect definitions", genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
return emitDialectDefs(records, os); return emitDialectDefs(records, os);
}); });

View File

@ -21,6 +21,9 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
/// File header and includes. /// File header and includes.
constexpr const char *fileHeader = R"Py( constexpr const char *fileHeader = R"Py(
@ -42,15 +45,15 @@ static std::string makePythonEnumCaseName(StringRef name) {
/// Emits the Python class for the given enum. /// Emits the Python class for the given enum.
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) { static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(), os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum"); enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
if (!enumAttr.getSummary().empty()) if (!enumAttr.getSummary().empty())
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary()); os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
os << "\n"; os << "\n";
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
os << llvm::formatv( os << formatv(" {0} = {1}\n",
" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()), makePythonEnumCaseName(enumCase.getSymbol()),
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
: "auto()"); : "auto()");
} }
@ -58,27 +61,25 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
os << "\n"; os << "\n";
if (enumAttr.isBitEnum()) { if (enumAttr.isBitEnum()) {
os << llvm::formatv(" def __iter__(self):\n" os << formatv(" def __iter__(self):\n"
" return iter([case for case in type(self) if " " return iter([case for case in type(self) if "
"(self & case) is case])\n"); "(self & case) is case])\n");
os << llvm::formatv(" def __len__(self):\n" os << formatv(" def __len__(self):\n"
" return bin(self).count(\"1\")\n"); " return bin(self).count(\"1\")\n");
os << "\n"; os << "\n";
} }
os << llvm::formatv(" def __str__(self):\n"); os << formatv(" def __str__(self):\n");
if (enumAttr.isBitEnum()) if (enumAttr.isBitEnum())
os << llvm::formatv(" if len(self) > 1:\n" os << formatv(" if len(self) > 1:\n"
" return \"{0}\".join(map(str, self))\n", " return \"{0}\".join(map(str, self))\n",
enumAttr.getDef().getValueAsString("separator")); enumAttr.getDef().getValueAsString("separator"));
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
os << llvm::formatv(" if self is {0}.{1}:\n", os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
enumAttr.getEnumClassName(),
makePythonEnumCaseName(enumCase.getSymbol())); makePythonEnumCaseName(enumCase.getSymbol()));
os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr()); os << formatv(" return \"{0}\"\n", enumCase.getStr());
} }
os << llvm::formatv( os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
enumAttr.getEnumClassName()); enumAttr.getEnumClassName());
os << "\n"; os << "\n";
} }
@ -105,12 +106,10 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true; return true;
} }
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", os << formatv("@register_attribute_builder(\"{0}\")\n",
enumAttr.getAttrDefName()); enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n", os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
enumAttr.getAttrDefName().lower()); os << formatv(" return "
os << llvm::formatv(
" return "
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
"context=context), int(x))\n\n", "context=context), int(x))\n\n",
bitwidth); bitwidth);
@ -123,9 +122,9 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
StringRef formatString, StringRef formatString,
raw_ostream &os) { raw_ostream &os) {
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower()); os << formatv("def _{0}(x, context):\n", attrDefName.lower());
os << llvm::formatv(" return " os << formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString); formatString);
return false; return false;
@ -133,16 +132,15 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
/// Emits Python bindings for all enums in the record keeper. Returns /// Emits Python bindings for all enums in the record keeper. Returns
/// `false` on success, `true` on failure. /// `false` on success, `true` on failure.
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper, static bool emitPythonEnums(const RecordKeeper &recordKeeper, raw_ostream &os) {
raw_ostream &os) {
os << fileHeader; os << fileHeader;
for (const llvm::Record *it : for (const Record *it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) { recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it); EnumAttr enumAttr(*it);
emitEnumClass(enumAttr, os); emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os); emitAttributeBuilder(enumAttr, os);
} }
for (const llvm::Record *it : for (const Record *it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) { recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it); AttrOrTypeDef attr(&*it);
if (!attr.getMnemonic()) { if (!attr.getMnemonic()) {
@ -156,11 +154,11 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
if (assemblyFormat == "`<` $value `>`") { if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder( emitDialectEnumAttributeBuilder(
attr.getName(), attr.getName(),
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
} else if (assemblyFormat == "$value") { } else if (assemblyFormat == "$value") {
emitDialectEnumAttributeBuilder( emitDialectEnumAttributeBuilder(
attr.getName(), attr.getName(),
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
} else { } else {
llvm::errs() llvm::errs()
<< "unsupported assembly format for python enum bindings generation"; << "unsupported assembly format for python enum bindings generation";

View File

@ -26,10 +26,9 @@
using llvm::formatv; using llvm::formatv;
using llvm::isDigit; using llvm::isDigit;
using llvm::PrintFatalError; using llvm::PrintFatalError;
using llvm::raw_ostream;
using llvm::Record; using llvm::Record;
using llvm::RecordKeeper; using llvm::RecordKeeper;
using llvm::StringRef; using namespace mlir;
using mlir::tblgen::Attribute; using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr; using mlir::tblgen::EnumAttr;
using mlir::tblgen::EnumAttrCase; using mlir::tblgen::EnumAttrCase;
@ -139,7 +138,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
// is not a power of two (i.e. not a single bit case) and not a known case. // is not a power of two (i.e. not a single bit case) and not a known case.
} else if (enumAttr.isBitEnum()) { } else if (enumAttr.isBitEnum()) {
// Process the known multi-bit cases that use valid keywords. // Process the known multi-bit cases that use valid keywords.
llvm::SmallVector<EnumAttrCase *> validMultiBitCases; SmallVector<EnumAttrCase *> validMultiBitCases;
for (auto [index, caseVal] : llvm::enumerate(cases)) { for (auto [index, caseVal] : llvm::enumerate(cases)) {
uint64_t value = caseVal.getValue(); uint64_t value = caseVal.getValue();
if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index)) if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
@ -476,7 +475,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef); EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName(); StringRef enumName = enumAttr.getEnumClassName();
StringRef attrClassName = enumAttr.getSpecializedAttrClassName(); StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
const llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass(); const Record *baseAttrDef = enumAttr.getBaseAttrClass();
Attribute baseAttr(baseAttrDef); Attribute baseAttr(baseAttrDef);
// Emit classof method // Emit classof method
@ -565,7 +564,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName(); StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases(); auto enumerants = enumAttr.getAllCases();
llvm::SmallVector<StringRef, 2> namespaces; SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::"); llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces) for (auto ns : namespaces)
@ -656,7 +655,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef); EnumAttr enumAttr(enumDef);
StringRef cppNamespace = enumAttr.getCppNamespace(); StringRef cppNamespace = enumAttr.getCppNamespace();
llvm::SmallVector<StringRef, 2> namespaces; SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::"); llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces) for (auto ns : namespaces)

View File

@ -13,6 +13,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::SourceMgr;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// FormatToken // FormatToken
@ -26,14 +27,14 @@ SMLoc FormatToken::getLoc() const {
// FormatLexer // FormatLexer
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc) FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc)
: mgr(mgr), loc(loc), : mgr(mgr), loc(loc),
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()), curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
curPtr(curBuffer.begin()) {} curPtr(curBuffer.begin()) {}
FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) { FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note, llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
"in custom assembly format for this operation"); "in custom assembly format for this operation");
return formToken(FormatToken::error, loc.getPointer()); return formToken(FormatToken::error, loc.getPointer());
} }
@ -44,10 +45,10 @@ FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg, FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
const Twine &note) { const Twine &note) {
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note, llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
"in custom assembly format for this operation"); "in custom assembly format for this operation");
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note); mgr.PrintMessage(loc, SourceMgr::DK_Note, note);
return formToken(FormatToken::error, loc.getPointer()); return formToken(FormatToken::error, loc.getPointer());
} }

View File

@ -411,8 +411,7 @@ public:
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API enumerant // (Enum) to the corresponding LLVM API enumerant
static void emitOneEnumToConversion(const llvm::Record *record, static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
raw_ostream &os) {
LLVMEnumAttr enumAttr(record); LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName(); StringRef cppClassName = enumAttr.getEnumClassName();
@ -441,8 +440,7 @@ static void emitOneEnumToConversion(const llvm::Record *record,
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API C-style enumerant // (Enum) to the corresponding LLVM API C-style enumerant
static void emitOneCEnumToConversion(const llvm::Record *record, static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
raw_ostream &os) {
LLVMCEnumAttr enumAttr(record); LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName(); StringRef cppClassName = enumAttr.getEnumClassName();
@ -472,8 +470,7 @@ static void emitOneCEnumToConversion(const llvm::Record *record,
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
// containing switch-based logic to convert from the LLVM API enumerant to MLIR // containing switch-based logic to convert from the LLVM API enumerant to MLIR
// LLVM dialect enum attribute (Enum). // LLVM dialect enum attribute (Enum).
static void emitOneEnumFromConversion(const llvm::Record *record, static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
raw_ostream &os) {
LLVMEnumAttr enumAttr(record); LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName(); StringRef cppClassName = enumAttr.getEnumClassName();
@ -508,8 +505,7 @@ static void emitOneEnumFromConversion(const llvm::Record *record,
// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and // Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
// containing switch-based logic to convert from the LLVM API C-style enumerant // containing switch-based logic to convert from the LLVM API C-style enumerant
// to MLIR LLVM dialect enum attribute (Enum). // to MLIR LLVM dialect enum attribute (Enum).
static void emitOneCEnumFromConversion(const llvm::Record *record, static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
raw_ostream &os) {
LLVMCEnumAttr enumAttr(record); LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName(); StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName(); StringRef cppClassName = enumAttr.getEnumClassName();

View File

@ -24,6 +24,11 @@
#include "llvm/TableGen/Record.h" #include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h" #include "llvm/TableGen/TableGenBackend.h"
using llvm::Record;
using llvm::RecordKeeper;
using llvm::Regex;
using namespace mlir;
static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options"); static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options");
static llvm::cl::opt<std::string> static llvm::cl::opt<std::string>
@ -54,14 +59,14 @@ static llvm::cl::opt<std::string> aliasAnalysisRegexp(
using IndicesTy = llvm::SmallBitVector; using IndicesTy = llvm::SmallBitVector;
/// Return a CodeGen value type entry from a type record. /// Return a CodeGen value type entry from a type record.
static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) { static llvm::MVT::SimpleValueType getValueType(const Record *rec) {
return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt( return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
"Value"); "Value");
} }
/// Return the indices of the definitions in a list of definitions that /// Return the indices of the definitions in a list of definitions that
/// represent overloadable types /// represent overloadable types
static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record, static IndicesTy getOverloadableTypeIdxs(const Record &record,
const char *listName) { const char *listName) {
auto results = record.getValueAsListOfDefs(listName); auto results = record.getValueAsListOfDefs(listName);
IndicesTy overloadedOps(results.size()); IndicesTy overloadedOps(results.size());
@ -87,13 +92,13 @@ namespace {
/// the fields of the record. /// the fields of the record.
class LLVMIntrinsic { class LLVMIntrinsic {
public: public:
LLVMIntrinsic(const llvm::Record &record) : record(record) {} LLVMIntrinsic(const Record &record) : record(record) {}
/// Get the name of the operation to be used in MLIR. Uses the appropriate /// Get the name of the operation to be used in MLIR. Uses the appropriate
/// field if not empty, constructs a name by replacing underscores with dots /// field if not empty, constructs a name by replacing underscores with dots
/// in the record name otherwise. /// in the record name otherwise.
std::string getOperationName() const { std::string getOperationName() const {
llvm::StringRef name = record.getValueAsString(fieldName); StringRef name = record.getValueAsString(fieldName);
if (!name.empty()) if (!name.empty())
return name.str(); return name.str();
@ -101,8 +106,8 @@ public:
assert(name.starts_with("int_") && assert(name.starts_with("int_") &&
"LLVM intrinsic names are expected to start with 'int_'"); "LLVM intrinsic names are expected to start with 'int_'");
name = name.drop_front(4); name = name.drop_front(4);
llvm::SmallVector<llvm::StringRef, 8> chunks; SmallVector<StringRef, 8> chunks;
llvm::StringRef targetPrefix = record.getValueAsString("TargetPrefix"); StringRef targetPrefix = record.getValueAsString("TargetPrefix");
name.split(chunks, '_'); name.split(chunks, '_');
auto *chunksBegin = chunks.begin(); auto *chunksBegin = chunks.begin();
// Remove the target prefix from target specific intrinsics. // Remove the target prefix from target specific intrinsics.
@ -119,8 +124,8 @@ public:
} }
/// Get the name of the record without the "intrinsic" prefix. /// Get the name of the record without the "intrinsic" prefix.
llvm::StringRef getProperRecordName() const { StringRef getProperRecordName() const {
llvm::StringRef name = record.getName(); StringRef name = record.getName();
assert(name.starts_with("int_") && assert(name.starts_with("int_") &&
"LLVM intrinsic names are expected to start with 'int_'"); "LLVM intrinsic names are expected to start with 'int_'");
return name.drop_front(4); return name.drop_front(4);
@ -129,10 +134,9 @@ public:
/// Get the number of operands. /// Get the number of operands.
unsigned getNumOperands() const { unsigned getNumOperands() const {
auto operands = record.getValueAsListOfDefs(fieldOperands); auto operands = record.getValueAsListOfDefs(fieldOperands);
assert(llvm::all_of(operands, assert(llvm::all_of(
[](const llvm::Record *r) { operands,
return r->isSubClassOf("LLVMType"); [](const Record *r) { return r->isSubClassOf("LLVMType"); }) &&
}) &&
"expected operands to be of LLVM type"); "expected operands to be of LLVM type");
return operands.size(); return operands.size();
} }
@ -142,7 +146,7 @@ public:
/// structure type. /// structure type.
unsigned getNumResults() const { unsigned getNumResults() const {
auto results = record.getValueAsListOfDefs(fieldResults); auto results = record.getValueAsListOfDefs(fieldResults);
for (const llvm::Record *r : results) { for (const Record *r : results) {
(void)r; (void)r;
assert(r->isSubClassOf("LLVMType") && assert(r->isSubClassOf("LLVMType") &&
"expected operands to be of LLVM type"); "expected operands to be of LLVM type");
@ -155,7 +159,7 @@ public:
bool hasSideEffects() const { bool hasSideEffects() const {
return llvm::none_of( return llvm::none_of(
record.getValueAsListOfDefs(fieldTraits), record.getValueAsListOfDefs(fieldTraits),
[](const llvm::Record *r) { return r->getName() == "IntrNoMem"; }); [](const Record *r) { return r->getName() == "IntrNoMem"; });
} }
/// Return true if the intrinsic is commutative, i.e. has the respective /// Return true if the intrinsic is commutative, i.e. has the respective
@ -163,7 +167,7 @@ public:
bool isCommutative() const { bool isCommutative() const {
return llvm::any_of( return llvm::any_of(
record.getValueAsListOfDefs(fieldTraits), record.getValueAsListOfDefs(fieldTraits),
[](const llvm::Record *r) { return r->getName() == "Commutative"; }); [](const Record *r) { return r->getName() == "Commutative"; });
} }
IndicesTy getOverloadableOperandsIdxs() const { IndicesTy getOverloadableOperandsIdxs() const {
@ -181,7 +185,7 @@ private:
const char *fieldResults = "RetTypes"; const char *fieldResults = "RetTypes";
const char *fieldTraits = "IntrProperties"; const char *fieldTraits = "IntrProperties";
const llvm::Record &record; const Record &record;
}; };
} // namespace } // namespace
@ -195,27 +199,26 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
/// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic. /// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic.
/// Returns true on error, false on success. /// Returns true on error, false on success.
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) { static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
LLVMIntrinsic intr(record); LLVMIntrinsic intr(record);
llvm::Regex accessGroupMatcher(accessGroupRegexp); Regex accessGroupMatcher(accessGroupRegexp);
bool requiresAccessGroup = bool requiresAccessGroup =
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName()); !accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
llvm::Regex aliasAnalysisMatcher(aliasAnalysisRegexp); Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() && bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() &&
aliasAnalysisMatcher.match(record.getName()); aliasAnalysisMatcher.match(record.getName());
// Prepare strings for traits, if any. // Prepare strings for traits, if any.
llvm::SmallVector<llvm::StringRef, 2> traits; SmallVector<StringRef, 2> traits;
if (intr.isCommutative()) if (intr.isCommutative())
traits.push_back("Commutative"); traits.push_back("Commutative");
if (!intr.hasSideEffects()) if (!intr.hasSideEffects())
traits.push_back("NoMemoryEffect"); traits.push_back("NoMemoryEffect");
// Prepare strings for operands. // Prepare strings for operands.
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(), SmallVector<StringRef, 8> operands(intr.getNumOperands(), "LLVM_Type");
"LLVM_Type");
if (requiresAccessGroup) if (requiresAccessGroup)
operands.push_back( operands.push_back(
"OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups"); "OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups");
@ -247,14 +250,13 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
/// Traverses the list of TableGen definitions derived from the "Intrinsic" /// Traverses the list of TableGen definitions derived from the "Intrinsic"
/// class and generates MLIR ODS definitions for those intrinsics that have /// class and generates MLIR ODS definitions for those intrinsics that have
/// the name matching the filter. /// the name matching the filter.
static bool emitIntrinsics(const llvm::RecordKeeper &records, static bool emitIntrinsics(const RecordKeeper &records, llvm::raw_ostream &os) {
llvm::raw_ostream &os) {
llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records); llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records);
os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n"; os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n";
os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n"; os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n";
auto defs = records.getAllDerivedDefinitions("Intrinsic"); auto defs = records.getAllDerivedDefinitions("Intrinsic");
for (const llvm::Record *r : defs) { for (const Record *r : defs) {
if (!nameFilter.empty() && !r->getName().contains(nameFilter)) if (!nameFilter.empty() && !r->getName().contains(nameFilter))
continue; continue;
if (emitIntrinsic(*r, os)) if (emitIntrinsic(*r, os))

View File

@ -34,30 +34,30 @@
#include <set> #include <set>
#include <string> #include <string>
//===----------------------------------------------------------------------===//
// Commandline Options
//===----------------------------------------------------------------------===//
static llvm::cl::OptionCategory
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
llvm::cl::opt<std::string>
stripPrefix("strip-prefix",
llvm::cl::desc("Strip prefix of the fully qualified names"),
llvm::cl::init("::mlir::"), llvm::cl::cat(docCat));
llvm::cl::opt<bool> allowHugoSpecificFeatures(
"allow-hugo-specific-features",
llvm::cl::desc("Allows using features specific to Hugo"),
llvm::cl::init(false), llvm::cl::cat(docCat));
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using mlir::tblgen::Operator; using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
// Commandline Options
//===----------------------------------------------------------------------===//
static cl::OptionCategory
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
cl::opt<std::string>
stripPrefix("strip-prefix",
cl::desc("Strip prefix of the fully qualified names"),
cl::init("::mlir::"), cl::cat(docCat));
cl::opt<bool> allowHugoSpecificFeatures(
"allow-hugo-specific-features",
cl::desc("Allows using features specific to Hugo"), cl::init(false),
cl::cat(docCat));
void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) { void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) {
if (!summary.empty()) { if (!summary.empty()) {
llvm::StringRef trimmed = summary.trim(); StringRef trimmed = summary.trim();
char first = std::toupper(trimmed.front()); char first = std::toupper(trimmed.front());
llvm::StringRef rest = trimmed.drop_front(); StringRef rest = trimmed.drop_front();
os << "\n_" << first << rest << "_\n\n"; os << "\n_" << first << rest << "_\n\n";
} }
} }
@ -152,10 +152,10 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
effectName.consume_front("::"); effectName.consume_front("::");
effectName.consume_front("mlir::"); effectName.consume_front("mlir::");
std::string effectStr; std::string effectStr;
llvm::raw_string_ostream os(effectStr); raw_string_ostream os(effectStr);
os << effectName << "{"; os << effectName << "{";
auto list = trait.getDef().getValueAsListOfDefs("effects"); auto list = trait.getDef().getValueAsListOfDefs("effects");
llvm::interleaveComma(list, os, [&](const Record *rec) { interleaveComma(list, os, [&](const Record *rec) {
StringRef effect = rec->getValueAsString("effect"); StringRef effect = rec->getValueAsString("effect");
effect.consume_front("::"); effect.consume_front("::");
effect.consume_front("mlir::"); effect.consume_front("mlir::");
@ -163,7 +163,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
}); });
os << "}"; os << "}";
effects.insert(backticks(effectStr)); effects.insert(backticks(effectStr));
name.append(llvm::formatv(" ({0})", traitName).str()); name.append(formatv(" ({0})", traitName).str());
} }
interfaces.insert(backticks(name)); interfaces.insert(backticks(name));
continue; continue;
@ -172,15 +172,15 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
traits.insert(backticks(name)); traits.insert(backticks(name));
} }
if (!traits.empty()) { if (!traits.empty()) {
llvm::interleaveComma(traits, os << "\nTraits: "); interleaveComma(traits, os << "\nTraits: ");
os << "\n"; os << "\n";
} }
if (!interfaces.empty()) { if (!interfaces.empty()) {
llvm::interleaveComma(interfaces, os << "\nInterfaces: "); interleaveComma(interfaces, os << "\nInterfaces: ");
os << "\n"; os << "\n";
} }
if (!effects.empty()) { if (!effects.empty()) {
llvm::interleaveComma(effects, os << "\nEffects: "); interleaveComma(effects, os << "\nEffects: ");
os << "\n"; os << "\n";
} }
} }
@ -196,7 +196,7 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
std::string classNameStr = op.getQualCppClassName(); std::string classNameStr = op.getQualCppClassName();
StringRef className = classNameStr; StringRef className = classNameStr;
(void)className.consume_front(stripPrefix); (void)className.consume_front(stripPrefix);
os << llvm::formatv("### `{0}` ({1})\n", op.getOperationName(), className); os << formatv("### `{0}` ({1})\n", op.getOperationName(), className);
// Emit the summary, syntax, and description if present. // Emit the summary, syntax, and description if present.
if (op.hasSummary()) if (op.hasSummary())
@ -287,7 +287,7 @@ static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n"; os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
emitSourceLink(recordKeeper.getInputFilename(), os); emitSourceLink(recordKeeper.getInputFilename(), os);
for (const llvm::Record *opDef : opDefs) for (const Record *opDef : opDefs)
emitOpDoc(Operator(opDef), os); emitOpDoc(Operator(opDef), os);
} }
@ -339,7 +339,7 @@ static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def,
} }
static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) { static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
os << llvm::formatv("### {0}\n", def.getCppClassName()); os << formatv("### {0}\n", def.getCppClassName());
// Emit the summary if present. // Emit the summary if present.
if (def.hasSummary()) if (def.hasSummary())
@ -376,7 +376,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName); auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName);
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n"; os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const llvm::Record *def : defs) for (const Record *def : defs)
emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os); emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os);
} }
@ -385,7 +385,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) { static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
os << llvm::formatv("### {0}\n", def.getEnumClassName()); os << formatv("### {0}\n", def.getEnumClassName());
// Emit the summary if present. // Emit the summary if present.
if (!def.getSummary().empty()) if (!def.getSummary().empty())
@ -406,8 +406,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n"; os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const llvm::Record *def : for (const Record *def : recordKeeper.getAllDerivedDefinitions("EnumAttr"))
recordKeeper.getAllDerivedDefinitions("EnumAttr"))
emitEnumDoc(EnumAttr(def), os); emitEnumDoc(EnumAttr(def), os);
} }
@ -431,7 +430,7 @@ struct OpDocGroup {
static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn, static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
raw_ostream &os) { raw_ostream &os) {
std::string str; std::string str;
llvm::raw_string_ostream ss(str); raw_string_ostream ss(str);
fn(ss); fn(ss);
for (StringRef x : llvm::split(str, "\n")) { for (StringRef x : llvm::split(str, "\n")) {
if (nest && x.starts_with("#")) if (nest && x.starts_with("#"))
@ -507,7 +506,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
emitIfNotEmpty(dialect.getDescription(), os); emitIfNotEmpty(dialect.getDescription(), os);
// Generate a TOC marker except if description already contains one. // Generate a TOC marker except if description already contains one.
llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline); Regex r("^[[:space:]]*\\[TOC\\]$", Regex::RegexFlags::Newline);
if (!r.match(dialect.getDescription())) if (!r.match(dialect.getDescription()))
os << "[TOC]\n\n"; os << "[TOC]\n\n";
@ -537,17 +536,15 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<TypeDef> dialectTypeDefs; std::vector<TypeDef> dialectTypeDefs;
std::vector<EnumAttr> dialectEnums; std::vector<EnumAttr> dialectEnums;
llvm::SmallDenseSet<const Record *> seen; SmallDenseSet<const Record *> seen;
auto addIfNotSeen = [&](const llvm::Record *record, const auto &def, auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) {
auto &vec) {
if (seen.insert(record).second) { if (seen.insert(record).second) {
vec.push_back(def); vec.push_back(def);
return true; return true;
} }
return false; return false;
}; };
auto addIfInDialect = [&](const llvm::Record *record, const auto &def, auto addIfInDialect = [&](const Record *record, const auto &def, auto &vec) {
auto &vec) {
return def.getDialect() == *dialect && addIfNotSeen(record, def, vec); return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
}; };

View File

@ -28,6 +28,9 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::StringMap;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// VariableElement // VariableElement
@ -404,7 +407,7 @@ struct OperationFormat {
StringRef opCppClassName; StringRef opCppClassName;
/// A map of buildable types to indices. /// A map of buildable types to indices.
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes; llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes;
/// The index of the buildable type, if valid, for every operand and result. /// The index of the buildable type, if valid, for every operand and result.
std::vector<TypeResolution> operandTypes, resultTypes; std::vector<TypeResolution> operandTypes, resultTypes;
@ -891,8 +894,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) { } else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
const NamedAttribute *var = attr->getVar(); const NamedAttribute *var = attr->getVar();
body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), body << formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), var->name);
var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) { } else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name; StringRef name = operand->getVar()->name;
@ -910,19 +912,19 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> " << " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
<< name << "Operands(&" << name << "RawOperand, 1);"; << name << "Operands(&" << name << "RawOperand, 1);";
} }
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n" body << formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
" (void){0}OperandsLoc;\n", " (void){0}OperandsLoc;\n",
name); name);
} else if (auto *region = dyn_cast<RegionVariable>(element)) { } else if (auto *region = dyn_cast<RegionVariable>(element)) {
StringRef name = region->getVar()->name; StringRef name = region->getVar()->name;
if (region->getVar()->isVariadic()) { if (region->getVar()->isVariadic()) {
body << llvm::formatv( body << formatv(
" ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> " " ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
"{0}Regions;\n", "{0}Regions;\n",
name); name);
} else { } else {
body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = " body << formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
"std::make_unique<::mlir::Region>();\n", "std::make_unique<::mlir::Region>();\n",
name); name);
} }
@ -930,11 +932,11 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) { } else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
StringRef name = successor->getVar()->name; StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic()) { if (successor->getVar()->isVariadic()) {
body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> " body << formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
"{0}Successors;\n", "{0}Successors;\n",
name); name);
} else { } else {
body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name); body << formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
} }
} else if (auto *dir = dyn_cast<TypeDirective>(element)) { } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
@ -944,8 +946,8 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n"; body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
else else
body body
<< llvm::formatv(" ::mlir::Type {0}RawType{{};\n", name) << formatv(" ::mlir::Type {0}RawType{{};\n", name)
<< llvm::formatv( << formatv(
" ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n", " ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
name); name);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) { } else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
@ -969,27 +971,27 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
StringRef name = operand->getVar()->name; StringRef name = operand->getVar()->name;
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv("{0}OperandGroups", name); body << formatv("{0}OperandGroups", name);
else if (lengthKind == ArgumentLengthKind::Variadic) else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Operands", name); body << formatv("{0}Operands", name);
else if (lengthKind == ArgumentLengthKind::Optional) else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Operand", name); body << formatv("{0}Operand", name);
else else
body << formatv("{0}RawOperand", name); body << formatv("{0}RawOperand", name);
} else if (auto *region = dyn_cast<RegionVariable>(param)) { } else if (auto *region = dyn_cast<RegionVariable>(param)) {
StringRef name = region->getVar()->name; StringRef name = region->getVar()->name;
if (region->getVar()->isVariadic()) if (region->getVar()->isVariadic())
body << llvm::formatv("{0}Regions", name); body << formatv("{0}Regions", name);
else else
body << llvm::formatv("*{0}Region", name); body << formatv("*{0}Region", name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(param)) { } else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
StringRef name = successor->getVar()->name; StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic()) if (successor->getVar()->isVariadic())
body << llvm::formatv("{0}Successors", name); body << formatv("{0}Successors", name);
else else
body << llvm::formatv("{0}Successor", name); body << formatv("{0}Successor", name);
} else if (auto *dir = dyn_cast<RefDirective>(param)) { } else if (auto *dir = dyn_cast<RefDirective>(param)) {
genCustomParameterParser(dir->getArg(), body); genCustomParameterParser(dir->getArg(), body);
@ -998,11 +1000,11 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
ArgumentLengthKind lengthKind; ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind); StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv("{0}TypeGroups", listName); body << formatv("{0}TypeGroups", listName);
else if (lengthKind == ArgumentLengthKind::Variadic) else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Types", listName); body << formatv("{0}Types", listName);
else if (lengthKind == ArgumentLengthKind::Optional) else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Type", listName); body << formatv("{0}Type", listName);
else else
body << formatv("{0}RawType", listName); body << formatv("{0}RawType", listName);
@ -1013,7 +1015,7 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
body << tgfmt(string->getValue(), &ctx); body << tgfmt(string->getValue(), &ctx);
} else if (auto *property = dyn_cast<PropertyVariable>(param)) { } else if (auto *property = dyn_cast<PropertyVariable>(param)) {
body << llvm::formatv("result.getOrAddProperties<Properties>().{0}", body << formatv("result.getOrAddProperties<Properties>().{0}",
property->getVar()->name); property->getVar()->name);
} else { } else {
llvm_unreachable("unknown custom directive parameter"); llvm_unreachable("unknown custom directive parameter");
@ -1037,12 +1039,12 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
body << " " << var->name body << " " << var->name
<< "OperandsLoc = parser.getCurrentLocation();\n"; << "OperandsLoc = parser.getCurrentLocation();\n";
if (var->isOptional()) { if (var->isOptional()) {
body << llvm::formatv( body << formatv(
" ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> " " ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
"{0}Operand;\n", "{0}Operand;\n",
var->name); var->name);
} else if (var->isVariadicOfVariadic()) { } else if (var->isVariadicOfVariadic()) {
body << llvm::formatv(" " body << formatv(" "
"::llvm::SmallVector<::llvm::SmallVector<::mlir::" "::llvm::SmallVector<::llvm::SmallVector<::mlir::"
"OpAsmParser::UnresolvedOperand>> " "OpAsmParser::UnresolvedOperand>> "
"{0}OperandGroups;\n", "{0}OperandGroups;\n",
@ -1052,9 +1054,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
ArgumentLengthKind lengthKind; ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind); StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) { if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); body << formatv(" ::mlir::Type {0}Type;\n", listName);
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv( body << formatv(
" ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> " " ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
"{0}TypeGroups;\n", "{0}TypeGroups;\n",
listName); listName);
@ -1064,7 +1066,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
if (auto *operand = dyn_cast<OperandVariable>(input)) { if (auto *operand = dyn_cast<OperandVariable>(input)) {
if (!operand->getVar()->isOptional()) if (!operand->getVar()->isOptional())
continue; continue;
body << llvm::formatv( body << formatv(
" {0} {1}Operand = {1}Operands.empty() ? {0}() : " " {0} {1}Operand = {1}Operands.empty() ? {0}() : "
"{1}Operands[0];\n", "{1}Operands[0];\n",
"::std::optional<::mlir::OpAsmParser::UnresolvedOperand>", "::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
@ -1074,7 +1076,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
ArgumentLengthKind lengthKind; ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(type->getArg(), lengthKind); StringRef listName = getTypeListName(type->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) { if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? " body << formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
"::mlir::Type() : {0}Types[0];\n", "::mlir::Type() : {0}Types[0];\n",
listName); listName);
} }
@ -1101,23 +1103,23 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
if (auto *attr = dyn_cast<AttributeVariable>(param)) { if (auto *attr = dyn_cast<AttributeVariable>(param)) {
const NamedAttribute *var = attr->getVar(); const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional() || var->attr.hasDefaultValue()) if (var->attr.isOptional() || var->attr.hasDefaultValue())
body << llvm::formatv(" if ({0}Attr)\n ", var->name); body << formatv(" if ({0}Attr)\n ", var->name);
if (useProperties) { if (useProperties) {
body << formatv( body << formatv(
" result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n", " result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
var->name, opCppClassName); var->name, opCppClassName);
} else { } else {
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n", body << formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
var->name); var->name);
} }
} else if (auto *operand = dyn_cast<OperandVariable>(param)) { } else if (auto *operand = dyn_cast<OperandVariable>(param)) {
const NamedTypeConstraint *var = operand->getVar(); const NamedTypeConstraint *var = operand->getVar();
if (var->isOptional()) { if (var->isOptional()) {
body << llvm::formatv(" if ({0}Operand.has_value())\n" body << formatv(" if ({0}Operand.has_value())\n"
" {0}Operands.push_back(*{0}Operand);\n", " {0}Operands.push_back(*{0}Operand);\n",
var->name); var->name);
} else if (var->isVariadicOfVariadic()) { } else if (var->isVariadicOfVariadic()) {
body << llvm::formatv( body << formatv(
" for (const auto &subRange : {0}OperandGroups) {{\n" " for (const auto &subRange : {0}OperandGroups) {{\n"
" {0}Operands.append(subRange.begin(), subRange.end());\n" " {0}Operands.append(subRange.begin(), subRange.end());\n"
" {0}OperandGroupSizes.push_back(subRange.size());\n" " {0}OperandGroupSizes.push_back(subRange.size());\n"
@ -1128,11 +1130,11 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
ArgumentLengthKind lengthKind; ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind); StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) { if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" if ({0}Type)\n" body << formatv(" if ({0}Type)\n"
" {0}Types.push_back({0}Type);\n", " {0}Types.push_back({0}Type);\n",
listName); listName);
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv( body << formatv(
" for (const auto &subRange : {0}TypeGroups)\n" " for (const auto &subRange : {0}TypeGroups)\n"
" {0}Types.append(subRange.begin(), subRange.end());\n", " {0}Types.append(subRange.begin(), subRange.end());\n",
listName); listName);
@ -1460,7 +1462,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << " if (" << attrVar->getVar()->name << "Attr) {\n"; body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) { } else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) {
genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false); genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false);
body << llvm::formatv("if ({0}PropParseResult.has_value() && " body << formatv("if ({0}PropParseResult.has_value() && "
"succeeded(*{0}PropParseResult)) ", "succeeded(*{0}PropParseResult)) ",
propVar->getVar()->name) propVar->getVar()->name)
<< " {\n"; << " {\n";
@ -1477,13 +1479,12 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
genElementParser(regionVar, body, attrTypeCtx); genElementParser(regionVar, body, attrTypeCtx);
body << " if (!" << region->name << "Regions.empty()) {\n"; body << " if (!" << region->name << "Regions.empty()) {\n";
} else { } else {
body << llvm::formatv(optionalRegionParserCode, region->name); body << formatv(optionalRegionParserCode, region->name);
body << " if (!" << region->name << "Region->empty()) {\n "; body << " if (!" << region->name << "Region->empty()) {\n ";
if (hasImplicitTermTrait) if (hasImplicitTermTrait)
body << llvm::formatv(regionEnsureTerminatorParserCode, region->name); body << formatv(regionEnsureTerminatorParserCode, region->name);
else if (hasSingleBlockTrait) else if (hasSingleBlockTrait)
body << llvm::formatv(regionEnsureSingleBlockParserCode, body << formatv(regionEnsureSingleBlockParserCode, region->name);
region->name);
} }
} else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) { } else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n"; body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
@ -1575,24 +1576,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name; StringRef name = operand->getVar()->name;
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv(variadicOfVariadicOperandParserCode, name); body << formatv(variadicOfVariadicOperandParserCode, name);
else if (lengthKind == ArgumentLengthKind::Variadic) else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicOperandParserCode, name); body << formatv(variadicOperandParserCode, name);
else if (lengthKind == ArgumentLengthKind::Optional) else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalOperandParserCode, name); body << formatv(optionalOperandParserCode, name);
else else
body << formatv(operandParserCode, name); body << formatv(operandParserCode, name);
} else if (auto *region = dyn_cast<RegionVariable>(element)) { } else if (auto *region = dyn_cast<RegionVariable>(element)) {
bool isVariadic = region->getVar()->isVariadic(); bool isVariadic = region->getVar()->isVariadic();
body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode, body << formatv(isVariadic ? regionListParserCode : regionParserCode,
region->getVar()->name); region->getVar()->name);
if (hasImplicitTermTrait) if (hasImplicitTermTrait)
body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode body << formatv(isVariadic ? regionListEnsureTerminatorParserCode
: regionEnsureTerminatorParserCode, : regionEnsureTerminatorParserCode,
region->getVar()->name); region->getVar()->name);
else if (hasSingleBlockTrait) else if (hasSingleBlockTrait)
body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode body << formatv(isVariadic ? regionListEnsureSingleBlockParserCode
: regionEnsureSingleBlockParserCode, : regionEnsureSingleBlockParserCode,
region->getVar()->name); region->getVar()->name);
@ -1631,24 +1632,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
<< " return ::mlir::failure();\n"; << " return ::mlir::failure();\n";
} else if (isa<RegionsDirective>(element)) { } else if (isa<RegionsDirective>(element)) {
body << llvm::formatv(regionListParserCode, "full"); body << formatv(regionListParserCode, "full");
if (hasImplicitTermTrait) if (hasImplicitTermTrait)
body << llvm::formatv(regionListEnsureTerminatorParserCode, "full"); body << formatv(regionListEnsureTerminatorParserCode, "full");
else if (hasSingleBlockTrait) else if (hasSingleBlockTrait)
body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full"); body << formatv(regionListEnsureSingleBlockParserCode, "full");
} else if (isa<SuccessorsDirective>(element)) { } else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full"); body << formatv(successorListParserCode, "full");
} else if (auto *dir = dyn_cast<TypeDirective>(element)) { } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind; ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind); StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); body << formatv(variadicOfVariadicTypeParserCode, listName);
} else if (lengthKind == ArgumentLengthKind::Variadic) { } else if (lengthKind == ArgumentLengthKind::Variadic) {
body << llvm::formatv(variadicTypeParserCode, listName); body << formatv(variadicTypeParserCode, listName);
} else if (lengthKind == ArgumentLengthKind::Optional) { } else if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(optionalTypeParserCode, listName); body << formatv(optionalTypeParserCode, listName);
} else { } else {
const char *parserCode = const char *parserCode =
dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode; dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
@ -1903,14 +1904,14 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
if (!operand.isVariadicOfVariadic()) if (!operand.isVariadicOfVariadic())
continue; continue;
if (op.getDialect().usePropertiesForAttributes()) { if (op.getDialect().usePropertiesForAttributes()) {
body << llvm::formatv( body << formatv(
" result.getOrAddProperties<{0}::Properties>().{1} = " " result.getOrAddProperties<{0}::Properties>().{1} = "
"parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n", "parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
op.getCppClassName(), op.getCppClassName(),
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
operand.name); operand.name);
} else { } else {
body << llvm::formatv( body << formatv(
" result.addAttribute(\"{0}\", " " result.addAttribute(\"{0}\", "
"parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));" "parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
"\n", "\n",
@ -2160,7 +2161,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
if (var->isVariadic()) if (var->isVariadic())
body << name << "().getTypes()"; body << name << "().getTypes()";
else if (var->isOptional()) else if (var->isOptional())
body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name); body << formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
else else
body << name << "().getType()"; body << name << "().getType()";
@ -2195,8 +2196,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
static void genRegionPrinter(const Twine &regionName, MethodBody &body, static void genRegionPrinter(const Twine &regionName, MethodBody &body,
bool hasImplicitTermTrait) { bool hasImplicitTermTrait) {
if (hasImplicitTermTrait) if (hasImplicitTermTrait)
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode, body << formatv(regionSingleBlockImplicitTerminatorPrinterCode, regionName);
regionName);
else else
body << " _odsPrinter.printRegion(" << regionName << ");\n"; body << " _odsPrinter.printRegion(" << regionName << ");\n";
} }
@ -2220,12 +2220,12 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
auto *operand = dyn_cast<OperandVariable>(arg); auto *operand = dyn_cast<OperandVariable>(arg);
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar(); auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
if (var->isVariadicOfVariadic()) if (var->isVariadicOfVariadic())
return body << llvm::formatv("{0}().join().getTypes()", return body << formatv("{0}().join().getTypes()",
op.getGetterName(var->name)); op.getGetterName(var->name));
if (var->isVariadic()) if (var->isVariadic())
return body << op.getGetterName(var->name) << "().getTypes()"; return body << op.getGetterName(var->name) << "().getTypes()";
if (var->isOptional()) if (var->isOptional())
return body << llvm::formatv( return body << formatv(
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : " "({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
"::llvm::ArrayRef<::mlir::Type>())", "::llvm::ArrayRef<::mlir::Type>())",
op.getGetterName(var->name)); op.getGetterName(var->name));
@ -2242,7 +2242,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr); const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases(); std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
body << llvm::formatv(enumAttrBeginPrinterCode, body << formatv(enumAttrBeginPrinterCode,
(var->attr.isOptional() ? "*" : "") + (var->attr.isOptional() ? "*" : "") +
op.getGetterName(var->name), op.getGetterName(var->name),
enumAttr.getSymbolToStringFnName()); enumAttr.getSymbolToStringFnName());
@ -2276,9 +2276,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
if (nonKeywordCases.test(it.index())) if (nonKeywordCases.test(it.index()))
continue; continue;
StringRef symbol = it.value().getSymbol(); StringRef symbol = it.value().getSymbol();
body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName, body << formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
llvm::isDigit(symbol.front()) ? ("_" + symbol) llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol);
: symbol);
} }
body << " _odsPrinter << caseValueStr;\n" body << " _odsPrinter << caseValueStr;\n"
" break;\n" " break;\n"
@ -2584,7 +2583,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) { } else if (auto *dir = dyn_cast<TypeDirective>(element)) {
if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) { if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
if (operand->getVar()->isVariadicOfVariadic()) { if (operand->getVar()->isVariadicOfVariadic()) {
body << llvm::formatv( body << formatv(
" ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, " " ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
"[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << " "[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
"types << \")\"; });\n", "types << \")\"; });\n",
@ -2710,7 +2709,7 @@ private:
/// Verify the state of operation operands within the format. /// Verify the state of operation operands within the format.
LogicalResult LogicalResult
verifyOperands(SMLoc loc, verifyOperands(SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver); StringMap<TypeResolutionInstance> &variableTyResolver);
/// Verify the state of operation regions within the format. /// Verify the state of operation regions within the format.
LogicalResult verifyRegions(SMLoc loc); LogicalResult verifyRegions(SMLoc loc);
@ -2718,7 +2717,7 @@ private:
/// Verify the state of operation results within the format. /// Verify the state of operation results within the format.
LogicalResult LogicalResult
verifyResults(SMLoc loc, verifyResults(SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver); StringMap<TypeResolutionInstance> &variableTyResolver);
/// Verify the state of operation successors within the format. /// Verify the state of operation successors within the format.
LogicalResult verifySuccessors(SMLoc loc); LogicalResult verifySuccessors(SMLoc loc);
@ -2730,18 +2729,17 @@ private:
/// resolution. /// resolution.
void handleAllTypesMatchConstraint( void handleAllTypesMatchConstraint(
ArrayRef<StringRef> values, ArrayRef<StringRef> values,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver); StringMap<TypeResolutionInstance> &variableTyResolver);
/// Check for inferable type resolution given all operands, and or results, /// Check for inferable type resolution given all operands, and or results,
/// have the same type. If 'includeResults' is true, the results also have the /// have the same type. If 'includeResults' is true, the results also have the
/// same type as all of the operands. /// same type as all of the operands.
void handleSameTypesConstraint( void handleSameTypesConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver, StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults); bool includeResults);
/// Check for inferable type resolution based on another operand, result, or /// Check for inferable type resolution based on another operand, result, or
/// attribute. /// attribute.
void handleTypesMatchConstraint( void handleTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver, StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
const llvm::Record &def);
/// Returns an argument or attribute with the given name that has been seen /// Returns an argument or attribute with the given name that has been seen
/// within the format. /// within the format.
@ -2794,9 +2792,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
"custom assembly format"); "custom assembly format");
// Check for any type traits that we can use for inferring types. // Check for any type traits that we can use for inferring types.
llvm::StringMap<TypeResolutionInstance> variableTyResolver; StringMap<TypeResolutionInstance> variableTyResolver;
for (const Trait &trait : op.getTraits()) { for (const Trait &trait : op.getTraits()) {
const llvm::Record &def = trait.getDef(); const Record &def = trait.getDef();
if (def.isSubClassOf("AllTypesMatch")) { if (def.isSubClassOf("AllTypesMatch")) {
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"), handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
variableTyResolver); variableTyResolver);
@ -2995,8 +2993,7 @@ OpFormatParser::verifyAttributeColonType(SMLoc loc,
return false; return false;
// If we encounter `:`, the range is known to be invalid. // If we encounter `:`, the range is known to be invalid.
(void)emitError( (void)emitError(
loc, loc, formatv("format ambiguity caused by `:` literal found after "
llvm::formatv("format ambiguity caused by `:` literal found after "
"attribute `{0}` which does not have a buildable type", "attribute `{0}` which does not have a buildable type",
cast<AttributeVariable>(base)->getVar()->name)); cast<AttributeVariable>(base)->getVar()->name));
return true; return true;
@ -3018,7 +3015,7 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
return false; return false;
(void)emitErrorAndNote( (void)emitErrorAndNote(
loc, loc,
llvm::formatv("format ambiguity caused by `attr-dict` directive " formatv("format ambiguity caused by `attr-dict` directive "
"followed by region `{0}`", "followed by region `{0}`",
region->getVar()->name), region->getVar()->name),
"try using `attr-dict-with-keyword` instead"); "try using `attr-dict-with-keyword` instead");
@ -3028,7 +3025,7 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
} }
LogicalResult OpFormatParser::verifyOperands( LogicalResult OpFormatParser::verifyOperands(
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
// Check that all of the operands are within the format, and their types can // Check that all of the operands are within the format, and their types can
// be inferred. // be inferred.
auto &buildableTypes = fmt.buildableTypes; auto &buildableTypes = fmt.buildableTypes;
@ -3093,7 +3090,7 @@ LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
} }
LogicalResult OpFormatParser::verifyResults( LogicalResult OpFormatParser::verifyResults(
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
// If we format all of the types together, there is nothing to check. // If we format all of the types together, there is nothing to check.
if (fmt.allResultTypes) if (fmt.allResultTypes)
return success(); return success();
@ -3197,7 +3194,7 @@ OpFormatParser::verifyOIListElements(SMLoc loc,
void OpFormatParser::handleAllTypesMatchConstraint( void OpFormatParser::handleAllTypesMatchConstraint(
ArrayRef<StringRef> values, ArrayRef<StringRef> values,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) { StringMap<TypeResolutionInstance> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) { for (unsigned i = 0, e = values.size(); i != e; ++i) {
// Check to see if this value matches a resolved operand or result type. // Check to see if this value matches a resolved operand or result type.
ConstArgument arg = findSeenArg(values[i]); ConstArgument arg = findSeenArg(values[i]);
@ -3213,7 +3210,7 @@ void OpFormatParser::handleAllTypesMatchConstraint(
} }
void OpFormatParser::handleSameTypesConstraint( void OpFormatParser::handleSameTypesConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver, StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults) { bool includeResults) {
const NamedTypeConstraint *resolver = nullptr; const NamedTypeConstraint *resolver = nullptr;
int resolvedIt = -1; int resolvedIt = -1;
@ -3238,8 +3235,7 @@ void OpFormatParser::handleSameTypesConstraint(
} }
void OpFormatParser::handleTypesMatchConstraint( void OpFormatParser::handleTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver, StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
const llvm::Record &def) {
StringRef lhsName = def.getValueAsString("lhs"); StringRef lhsName = def.getValueAsString("lhs");
StringRef rhsName = def.getValueAsString("rhs"); StringRef rhsName = def.getValueAsString("rhs");
StringRef transformer = def.getValueAsString("transformer"); StringRef transformer = def.getValueAsString("transformer");

View File

@ -41,7 +41,7 @@ static std::string getOperationName(const Record &def) {
auto opName = def.getValueAsString("opName"); auto opName = def.getValueAsString("opName");
if (prefix.empty()) if (prefix.empty())
return std::string(opName); return std::string(opName);
return std::string(llvm::formatv("{0}.{1}", prefix, opName)); return std::string(formatv("{0}.{1}", prefix, opName));
} }
std::vector<const Record *> std::vector<const Record *>
@ -50,7 +50,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
if (!classDef) if (!classDef)
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n"); PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter); Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
std::vector<const Record *> defs; std::vector<const Record *> defs;
for (const auto &def : recordKeeper.getDefs()) { for (const auto &def : recordKeeper.getDefs()) {
if (!def.second->isSubClassOf(classDef)) if (!def.second->isSubClassOf(classDef))
@ -70,7 +70,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
} }
bool mlir::tblgen::isPythonReserved(StringRef str) { bool mlir::tblgen::isPythonReserved(StringRef str) {
static llvm::StringSet<> reserved({ static StringSet<> reserved({
"False", "None", "True", "and", "as", "assert", "async", "False", "None", "True", "and", "as", "assert", "async",
"await", "break", "class", "continue", "def", "del", "elif", "await", "break", "class", "continue", "def", "del", "elif",
"else", "except", "finally", "for", "from", "global", "if", "else", "except", "finally", "for", "from", "global", "if",
@ -86,8 +86,8 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
} }
void mlir::tblgen::shardOpDefinitions( void mlir::tblgen::shardOpDefinitions(
ArrayRef<const llvm::Record *> defs, ArrayRef<const Record *> defs,
SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs) { SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
assert(opShardCount > 0 && "expected a positive shard count"); assert(opShardCount > 0 && "expected a positive shard count");
if (opShardCount == 1) { if (opShardCount == 1) {
shardedDefs.push_back(defs); shardedDefs.push_back(defs);

View File

@ -23,6 +23,8 @@
#include "llvm/TableGen/TableGenBackend.h" #include "llvm/TableGen/TableGenBackend.h"
using namespace mlir; using namespace mlir;
using llvm::Record;
using llvm::RecordKeeper;
using mlir::tblgen::Interface; using mlir::tblgen::Interface;
using mlir::tblgen::InterfaceMethod; using mlir::tblgen::InterfaceMethod;
using mlir::tblgen::OpInterface; using mlir::tblgen::OpInterface;
@ -61,14 +63,13 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
/// Get an array of all OpInterface definitions but exclude those subclassing /// Get an array of all OpInterface definitions but exclude those subclassing
/// "DeclareOpInterfaceMethods". /// "DeclareOpInterfaceMethods".
static std::vector<const llvm::Record *> static std::vector<const Record *>
getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, getAllInterfaceDefinitions(const RecordKeeper &recordKeeper, StringRef name) {
StringRef name) { std::vector<const Record *> defs =
std::vector<const llvm::Record *> defs =
recordKeeper.getAllDerivedDefinitions((name + "Interface").str()); recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
std::string declareName = ("Declare" + name + "InterfaceMethods").str(); std::string declareName = ("Declare" + name + "InterfaceMethods").str();
llvm::erase_if(defs, [&](const llvm::Record *def) { llvm::erase_if(defs, [&](const Record *def) {
// Ignore any "declare methods" interfaces. // Ignore any "declare methods" interfaces.
if (def->isSubClassOf(declareName)) if (def->isSubClassOf(declareName))
return true; return true;
@ -88,7 +89,7 @@ public:
bool emitInterfaceDocs(); bool emitInterfaceDocs();
protected: protected:
InterfaceGenerator(std::vector<const llvm::Record *> &&defs, raw_ostream &os) InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
: defs(std::move(defs)), os(os) {} : defs(std::move(defs)), os(os) {}
void emitConceptDecl(const Interface &interface); void emitConceptDecl(const Interface &interface);
@ -99,7 +100,7 @@ protected:
void emitInterfaceDecl(const Interface &interface); void emitInterfaceDecl(const Interface &interface);
/// The set of interface records to emit. /// The set of interface records to emit.
std::vector<const llvm::Record *> defs; std::vector<const Record *> defs;
// The stream to emit to. // The stream to emit to.
raw_ostream &os; raw_ostream &os;
/// The C++ value type of the interface, e.g. Operation*. /// The C++ value type of the interface, e.g. Operation*.
@ -118,7 +119,7 @@ protected:
/// A specialized generator for attribute interfaces. /// A specialized generator for attribute interfaces.
struct AttrInterfaceGenerator : public InterfaceGenerator { struct AttrInterfaceGenerator : public InterfaceGenerator {
AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
valueType = "::mlir::Attribute"; valueType = "::mlir::Attribute";
interfaceBaseType = "AttributeInterface"; interfaceBaseType = "AttributeInterface";
@ -133,7 +134,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
}; };
/// A specialized generator for operation interfaces. /// A specialized generator for operation interfaces.
struct OpInterfaceGenerator : public InterfaceGenerator { struct OpInterfaceGenerator : public InterfaceGenerator {
OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
valueType = "::mlir::Operation *"; valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface"; interfaceBaseType = "OpInterface";
@ -149,7 +150,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
}; };
/// A specialized generator for type interfaces. /// A specialized generator for type interfaces.
struct TypeInterfaceGenerator : public InterfaceGenerator { struct TypeInterfaceGenerator : public InterfaceGenerator {
TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
valueType = "::mlir::Type"; valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface"; interfaceBaseType = "TypeInterface";
@ -607,13 +608,13 @@ bool InterfaceGenerator::emitInterfaceDecls() {
llvm::emitSourceFileHeader("Interface Declarations", os); llvm::emitSourceFileHeader("Interface Declarations", os);
// Sort according to ID, so defs are emitted in the order in which they appear // Sort according to ID, so defs are emitted in the order in which they appear
// in the Tablegen file. // in the Tablegen file.
std::vector<const llvm::Record *> sortedDefs(defs); std::vector<const Record *> sortedDefs(defs);
llvm::sort(sortedDefs, [](const llvm::Record *lhs, const llvm::Record *rhs) { llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID(); return lhs->getID() < rhs->getID();
}); });
for (const llvm::Record *def : sortedDefs) for (const Record *def : sortedDefs)
emitInterfaceDecl(Interface(def)); emitInterfaceDecl(Interface(def));
for (const llvm::Record *def : sortedDefs) for (const Record *def : sortedDefs)
emitModelMethodsDef(Interface(def)); emitModelMethodsDef(Interface(def));
return false; return false;
} }
@ -622,8 +623,7 @@ bool InterfaceGenerator::emitInterfaceDecls() {
// GEN: Interface documentation // GEN: Interface documentation
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static void emitInterfaceDoc(const llvm::Record &interfaceDef, static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
raw_ostream &os) {
Interface interface(&interfaceDef); Interface interface(&interfaceDef);
// Emit the interface name followed by the description. // Emit the interface name followed by the description.
@ -684,15 +684,15 @@ struct InterfaceGenRegistration {
genDefDesc(("Generate " + genDesc + " interface definitions").str()), genDefDesc(("Generate " + genDesc + " interface definitions").str()),
genDocDesc(("Generate " + genDesc + " interface documentation").str()), genDocDesc(("Generate " + genDesc + " interface documentation").str()),
genDecls(genDeclArg, genDeclDesc, genDecls(genDeclArg, genDeclDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDecls(); return GeneratorT(records, os).emitInterfaceDecls();
}), }),
genDefs(genDefArg, genDefDesc, genDefs(genDefArg, genDefDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDefs(); return GeneratorT(records, os).emitInterfaceDefs();
}), }),
genDocs(genDocArg, genDocDesc, genDocs(genDocArg, genDocDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDocs(); return GeneratorT(records, os).emitInterfaceDocs();
}) {} }) {}

View File

@ -23,6 +23,9 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
/// File header and includes. /// File header and includes.
/// {0} is the dialect namespace. /// {0} is the dialect namespace.
@ -315,9 +318,9 @@ static std::string sanitizeName(StringRef name) {
} }
static std::string attrSizedTraitForKind(const char *kind) { static std::string attrSizedTraitForKind(const char *kind) {
return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
llvm::StringRef(kind).take_front().upper(), StringRef(kind).take_front().upper(),
llvm::StringRef(kind).drop_front()); StringRef(kind).drop_front());
} }
/// Emits accessors to "elements" of an Op definition. Currently, the supported /// Emits accessors to "elements" of an Op definition. Currently, the supported
@ -328,15 +331,14 @@ static void emitElementAccessors(
unsigned numVariadicGroups, unsigned numElements, unsigned numVariadicGroups, unsigned numElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) { getElement) {
assert(llvm::is_contained( assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && kind) &&
"unsupported kind"); "unsupported kind");
// Traits indicating how to process variadic elements. // Traits indicating how to process variadic elements.
std::string sameSizeTrait = std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", StringRef(kind).take_front().upper(),
llvm::StringRef(kind).take_front().upper(), StringRef(kind).drop_front());
llvm::StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind); std::string attrSizedTrait = attrSizedTraitForKind(kind);
// If there is only one variable-length element group, its size can be // If there is only one variable-length element group, its size can be
@ -351,15 +353,14 @@ static void emitElementAccessors(
if (element.name.empty()) if (element.name.empty())
continue; continue;
if (element.isVariableLength()) { if (element.isVariableLength()) {
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate os << formatv(element.isOptional() ? opOneOptionalTemplate
: opOneVariadicTemplate, : opOneVariadicTemplate,
sanitizeName(element.name), kind, numElements, i); sanitizeName(element.name), kind, numElements, i);
} else if (seenVariableLength) { } else if (seenVariableLength) {
os << llvm::formatv(opSingleAfterVariableTemplate, os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
sanitizeName(element.name), kind, numElements, i); kind, numElements, i);
} else { } else {
os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
i);
} }
} }
return; return;
@ -382,11 +383,10 @@ static void emitElementAccessors(
for (unsigned i = 0; i < numElements; ++i) { for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i); const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) { if (!element.name.empty()) {
os << llvm::formatv(opVariadicEqualPrefixTemplate, os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
sanitizeName(element.name), kind, numSimpleLength, kind, numSimpleLength, numVariadicGroups,
numVariadicGroups, numPrecedingSimple, numPrecedingSimple, numPrecedingVariadic);
numPrecedingVariadic); os << formatv(element.isVariableLength()
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate ? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate, : opVariadicEqualSimpleTemplate,
kind); kind);
@ -412,9 +412,9 @@ static void emitElementAccessors(
trailing = "[0]"; trailing = "[0]";
else if (element.isOptional()) else if (element.isOptional())
trailing = std::string( trailing = std::string(
llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
kind, i, trailing); i, trailing);
} }
return; return;
} }
@ -459,27 +459,21 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
// Unit attributes are handled specially. // Unit attributes are handled specially.
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") { if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
namedAttr.name); os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name);
continue; continue;
} }
if (namedAttr.attr.isOptional()) { if (namedAttr.attr.isOptional()) {
os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, os << formatv(optionalAttributeGetterTemplate, sanitizedName,
namedAttr.name); namedAttr.name);
os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, os << formatv(optionalAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name); namedAttr.name);
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
} else { } else {
os << llvm::formatv(attributeGetterTemplate, sanitizedName, os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
namedAttr.name); os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
os << llvm::formatv(attributeSetterTemplate, sanitizedName,
namedAttr.name);
// Non-optional attributes cannot be deleted. // Non-optional attributes cannot be deleted.
} }
} }
@ -595,7 +589,7 @@ static bool canInferType(const Operator &op) {
/// accept them as arguments. /// accept them as arguments.
static void static void
populateBuilderArgsResults(const Operator &op, populateBuilderArgsResults(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs) { SmallVectorImpl<std::string> &builderArgs) {
if (canInferType(op)) if (canInferType(op))
return; return;
@ -607,7 +601,7 @@ populateBuilderArgsResults(const Operator &op,
// to properly match the built-in result accessor. // to properly match the built-in result accessor.
name = "result"; name = "result";
} else { } else {
name = llvm::formatv("_gen_res_{0}", i); name = formatv("_gen_res_{0}", i);
} }
} }
name = sanitizeName(name); name = sanitizeName(name);
@ -620,14 +614,13 @@ populateBuilderArgsResults(const Operator &op,
/// appear in the `arguments` field of the op definition. Additionally, /// appear in the `arguments` field of the op definition. Additionally,
/// `operandNames` is populated with names of operands in their order of /// `operandNames` is populated with names of operands in their order of
/// appearance. /// appearance.
static void static void populateBuilderArgs(const Operator &op,
populateBuilderArgs(const Operator &op, SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &builderArgs, SmallVectorImpl<std::string> &operandNames) {
llvm::SmallVectorImpl<std::string> &operandNames) {
for (int i = 0, e = op.getNumArgs(); i < e; ++i) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str(); std::string name = op.getArgName(i).str();
if (name.empty()) if (name.empty())
name = llvm::formatv("_gen_arg_{0}", i); name = formatv("_gen_arg_{0}", i);
name = sanitizeName(name); name = sanitizeName(name);
builderArgs.push_back(name); builderArgs.push_back(name);
if (!op.getArg(i).is<NamedAttribute *>()) if (!op.getArg(i).is<NamedAttribute *>())
@ -637,15 +630,16 @@ populateBuilderArgs(const Operator &op,
/// Populates `builderArgs` with the Python-compatible names of builder function /// Populates `builderArgs` with the Python-compatible names of builder function
/// successor arguments. Additionally, `successorArgNames` is also populated. /// successor arguments. Additionally, `successorArgNames` is also populated.
static void populateBuilderArgsSuccessors( static void
const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs, populateBuilderArgsSuccessors(const Operator &op,
llvm::SmallVectorImpl<std::string> &successorArgNames) { SmallVectorImpl<std::string> &builderArgs,
SmallVectorImpl<std::string> &successorArgNames) {
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
NamedSuccessor successor = op.getSuccessor(i); NamedSuccessor successor = op.getSuccessor(i);
std::string name = std::string(successor.name); std::string name = std::string(successor.name);
if (name.empty()) if (name.empty())
name = llvm::formatv("_gen_successor_{0}", i); name = formatv("_gen_successor_{0}", i);
name = sanitizeName(name); name = sanitizeName(name);
builderArgs.push_back(name); builderArgs.push_back(name);
successorArgNames.push_back(name); successorArgNames.push_back(name);
@ -658,9 +652,8 @@ static void populateBuilderArgsSuccessors(
/// operands and attributes in the same order as they appear in the `arguments` /// operands and attributes in the same order as they appear in the `arguments`
/// field. /// field.
static void static void
populateBuilderLinesAttr(const Operator &op, populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
llvm::ArrayRef<std::string> argNames, SmallVectorImpl<std::string> &builderLines) {
llvm::SmallVectorImpl<std::string> &builderLines) {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)"); builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i); Argument arg = op.getArg(i);
@ -670,12 +663,12 @@ populateBuilderLinesAttr(const Operator &op,
// Unit attributes are handled specially. // Unit attributes are handled specially.
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") { if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, builderLines.push_back(
attribute->name, argNames[i])); formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
continue; continue;
} }
builderLines.push_back(llvm::formatv( builderLines.push_back(formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue() attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate ? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate, : initAttributeWithBuilderTemplate,
@ -686,30 +679,30 @@ populateBuilderLinesAttr(const Operator &op,
/// Populates `builderLines` with additional lines that are required in the /// Populates `builderLines` with additional lines that are required in the
/// builder to set up successors. successorArgNames is expected to correspond /// builder to set up successors. successorArgNames is expected to correspond
/// to the Python argument name for each successor on the op. /// to the Python argument name for each successor on the op.
static void populateBuilderLinesSuccessors( static void
const Operator &op, llvm::ArrayRef<std::string> successorArgNames, populateBuilderLinesSuccessors(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderLines) { ArrayRef<std::string> successorArgNames,
SmallVectorImpl<std::string> &builderLines) {
if (successorArgNames.empty()) { if (successorArgNames.empty()) {
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
return; return;
} }
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
for (int i = 0, e = successorArgNames.size(); i < e; ++i) { for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
auto &argName = successorArgNames[i]; auto &argName = successorArgNames[i];
const NamedSuccessor &successor = op.getSuccessor(i); const NamedSuccessor &successor = op.getSuccessor(i);
builderLines.push_back( builderLines.push_back(formatv(addSuccessorTemplate,
llvm::formatv(addSuccessorTemplate, successor.isVariadic() ? "extend" : "append",
successor.isVariadic() ? "extend" : "append", argName)); argName));
} }
} }
/// Populates `builderLines` with additional lines that are required in the /// Populates `builderLines` with additional lines that are required in the
/// builder to set up op operands. /// builder to set up op operands.
static void static void
populateBuilderLinesOperand(const Operator &op, populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
llvm::ArrayRef<std::string> names, SmallVectorImpl<std::string> &builderLines) {
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
// For each element, find or generate a name. // For each element, find or generate a name.
@ -718,7 +711,7 @@ populateBuilderLinesOperand(const Operator &op,
std::string name = names[i]; std::string name = names[i];
// Choose the formatting string based on the element kind. // Choose the formatting string based on the element kind.
llvm::StringRef formatString; StringRef formatString;
if (!element.isVariableLength()) { if (!element.isVariableLength()) {
formatString = singleOperandAppendTemplate; formatString = singleOperandAppendTemplate;
} else if (element.isOptional()) { } else if (element.isOptional()) {
@ -738,7 +731,7 @@ populateBuilderLinesOperand(const Operator &op,
} }
} }
builderLines.push_back(llvm::formatv(formatString.data(), name)); builderLines.push_back(formatv(formatString.data(), name));
} }
} }
@ -758,7 +751,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// Appends the given multiline string as individual strings into /// Appends the given multiline string as individual strings into
/// `builderLines`. /// `builderLines`.
static void appendLineByLine(StringRef string, static void appendLineByLine(StringRef string,
llvm::SmallVectorImpl<std::string> &builderLines) { SmallVectorImpl<std::string> &builderLines) {
std::pair<StringRef, StringRef> split = std::make_pair(string, string); std::pair<StringRef, StringRef> split = std::make_pair(string, string);
do { do {
@ -770,14 +763,13 @@ static void appendLineByLine(StringRef string,
/// Populates `builderLines` with additional lines that are required in the /// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results. /// builder to set up op results.
static void static void
populateBuilderLinesResult(const Operator &op, populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
llvm::ArrayRef<std::string> names, SmallVectorImpl<std::string> &builderLines) {
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
if (hasSameArgumentAndResultTypes(op)) { if (hasSameArgumentAndResultTypes(op)) {
builderLines.push_back(llvm::formatv( builderLines.push_back(formatv(appendSameResultsTemplate,
appendSameResultsTemplate, "operands[0].type", op.getNumResults())); "operands[0].type", op.getNumResults()));
return; return;
} }
@ -785,10 +777,9 @@ populateBuilderLinesResult(const Operator &op,
const NamedAttribute &firstAttr = op.getAttribute(0); const NamedAttribute &firstAttr = op.getAttribute(0);
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
"from which the type is derived"); "from which the type is derived");
appendLineByLine( appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
builderLines); builderLines);
builderLines.push_back(llvm::formatv(appendSameResultsTemplate, builderLines.push_back(formatv(appendSameResultsTemplate,
"_ods_derived_result_type", "_ods_derived_result_type",
op.getNumResults())); op.getNumResults()));
return; return;
@ -803,7 +794,7 @@ populateBuilderLinesResult(const Operator &op,
std::string name = names[i]; std::string name = names[i];
// Choose the formatting string based on the element kind. // Choose the formatting string based on the element kind.
llvm::StringRef formatString; StringRef formatString;
if (!element.isVariableLength()) { if (!element.isVariableLength()) {
formatString = singleResultAppendTemplate; formatString = singleResultAppendTemplate;
} else if (element.isOptional()) { } else if (element.isOptional()) {
@ -819,17 +810,16 @@ populateBuilderLinesResult(const Operator &op,
} }
} }
builderLines.push_back(llvm::formatv(formatString.data(), name)); builderLines.push_back(formatv(formatString.data(), name));
} }
} }
/// If the operation has variadic regions, adds a builder argument to specify /// If the operation has variadic regions, adds a builder argument to specify
/// the number of those regions and builder lines to forward it to the generic /// the number of those regions and builder lines to forward it to the generic
/// constructor. /// constructor.
static void static void populateBuilderRegions(const Operator &op,
populateBuilderRegions(const Operator &op, SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &builderArgs, SmallVectorImpl<std::string> &builderLines) {
llvm::SmallVectorImpl<std::string> &builderLines) {
if (op.hasNoVariadicRegions()) if (op.hasNoVariadicRegions())
return; return;
@ -844,19 +834,19 @@ populateBuilderRegions(const Operator &op,
.str(); .str();
builderArgs.push_back(name); builderArgs.push_back(name);
builderLines.push_back( builderLines.push_back(
llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
} }
/// Emits a default builder constructing an operation from the list of its /// Emits a default builder constructing an operation from the list of its
/// result types, followed by a list of its operands. Returns vector /// result types, followed by a list of its operands. Returns vector
/// of fully built functionArgs for downstream users (to save having to /// of fully built functionArgs for downstream users (to save having to
/// rebuild anew). /// rebuild anew).
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op, static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) { raw_ostream &os) {
llvm::SmallVector<std::string> builderArgs; SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines; SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames; SmallVector<std::string> operandArgNames;
llvm::SmallVector<std::string> successorArgNames; SmallVector<std::string> successorArgNames;
builderArgs.reserve(op.getNumOperands() + op.getNumResults() + builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
op.getNumNativeAttributes() + op.getNumSuccessors()); op.getNumNativeAttributes() + op.getNumSuccessors());
populateBuilderArgsResults(op, builderArgs); populateBuilderArgsResults(op, builderArgs);
@ -866,10 +856,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames); populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
populateBuilderLinesOperand(op, operandArgNames, builderLines); populateBuilderLinesOperand(op, operandArgNames, builderLines);
populateBuilderLinesAttr( populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines); builderLines);
populateBuilderLinesResult( populateBuilderLinesResult(
op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines); op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
populateBuilderLinesSuccessors(op, successorArgNames, builderLines); populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
populateBuilderRegions(op, builderArgs, builderLines); populateBuilderRegions(op, builderArgs, builderLines);
@ -896,7 +886,7 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
}; };
// StringRefs in functionArgs refer to strings allocated by builderArgs. // StringRefs in functionArgs refer to strings allocated by builderArgs.
llvm::SmallVector<llvm::StringRef> functionArgs; SmallVector<StringRef> functionArgs;
// Add positional arguments. // Add positional arguments.
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) { for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
@ -929,11 +919,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
initArgs.push_back("loc=loc"); initArgs.push_back("loc=loc");
initArgs.push_back("ip=ip"); initArgs.push_back("ip=ip");
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "), os << formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "), llvm::join(builderLines, "\n "), llvm::join(initArgs, ", "));
llvm::join(initArgs, ", "));
return llvm::to_vector<8>( return llvm::to_vector<8>(
llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); })); llvm::map_range(functionArgs, [](StringRef s) { return s.str(); }));
} }
static void emitSegmentSpec( static void emitSegmentSpec(
@ -955,14 +944,14 @@ static void emitSegmentSpec(
} }
segmentSpec.append("]"); segmentSpec.append("]");
os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
} }
static void emitRegionAttributes(const Operator &op, raw_ostream &os) { static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
// Note that the base OpView class defines this as (0, True). // Note that the base OpView class defines this as (0, True).
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, os << formatv(opClassRegionSpecTemplate, minRegionCount,
op.hasNoVariadicRegions() ? "True" : "False"); op.hasNoVariadicRegions() ? "True" : "False");
} }
@ -975,7 +964,7 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
"expected only the last region to be variadic"); "expected only the last region to be variadic");
os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), os << formatv(regionAccessorTemplate, sanitizeName(region.name),
std::to_string(en.index()) + std::to_string(en.index()) +
(region.isVariadic() ? ":" : "")); (region.isVariadic() ? ":" : ""));
} }
@ -983,12 +972,12 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
/// Emits builder that extracts results from op /// Emits builder that extracts results from op
static void emitValueBuilder(const Operator &op, static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs, SmallVector<std::string> functionArgs,
raw_ostream &os) { raw_ostream &os) {
// Params with (possibly) default args. // Params with (possibly) default args.
auto valueBuilderParams = auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) { llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
llvm::SmallVector<llvm::StringRef> argMaybeDefault = SmallVector<StringRef> argMaybeDefault =
llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "=")); llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]); auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
if (argMaybeDefault.size() == 2) if (argMaybeDefault.size() == 2)
@ -1005,7 +994,7 @@ static void emitValueBuilder(const Operator &op,
}); });
std::string nameWithoutDialect = std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1); op.getOperationName().substr(op.getOperationName().find('.') + 1);
os << llvm::formatv( os << formatv(
valueBuilderTemplate, sanitizeName(nameWithoutDialect), valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(), llvm::join(valueBuilderParams, ", "), op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "), llvm::join(opBuilderArgs, ", "),
@ -1016,8 +1005,7 @@ static void emitValueBuilder(const Operator &op,
/// Emits bindings for a specific Op to the given output stream. /// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) { static void emitOpBindings(const Operator &op, raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(), os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());
op.getOperationName());
// Sized segments. // Sized segments.
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
@ -1028,7 +1016,7 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
} }
emitRegionAttributes(op, os); emitRegionAttributes(op, os);
llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os); SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os); emitOperandAccessors(op, os);
emitAttributeAccessors(op, os); emitAttributeAccessors(op, os);
emitResultAccessors(op, os); emitResultAccessors(op, os);
@ -1039,17 +1027,17 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
/// Emits bindings for the dialect specified in the command line, including file /// Emits bindings for the dialect specified in the command line, including file
/// headers and utilities. Returns `false` on success to comply with Tablegen /// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements. /// registration requirements.
static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty()) if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided"); llvm::PrintFatalError("dialect name not provided");
os << fileHeader; os << fileHeader;
if (!clDialectExtensionName.empty()) if (!clDialectExtensionName.empty())
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue()); os << formatv(dialectExtensionTemplate, clDialectName.getValue());
else else
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); os << formatv(dialectClassTemplate, clDialectName.getValue());
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec); Operator op(rec);
if (op.getDialectName() == clDialectName.getValue()) if (op.getDialectName() == clDialectName.getValue())
emitOpBindings(op, os); emitOpBindings(op, os);

View File

@ -20,6 +20,8 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::formatv;
using llvm::RecordKeeper;
static llvm::cl::OptionCategory static llvm::cl::OptionCategory
passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl"); passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl");
@ -56,7 +58,7 @@ const char *const fileFooter = R"(
)"; )";
/// Emit TODO /// Emit TODO
static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) { static bool emitCAPIHeader(const RecordKeeper &records, raw_ostream &os) {
os << fileHeader; os << fileHeader;
os << "// Registration for the entire group\n"; os << "// Registration for the entire group\n";
os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName
@ -64,7 +66,7 @@ static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
Pass pass(def); Pass pass(def);
StringRef defName = pass.getDef()->getName(); StringRef defName = pass.getDef()->getName();
os << llvm::formatv(passDecl, groupName, defName); os << formatv(passDecl, groupName, defName);
} }
os << fileFooter; os << fileFooter;
return false; return false;
@ -91,9 +93,9 @@ void mlirRegister{0}Passes(void) {{
} }
)"; )";
static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) { static bool emitCAPIImpl(const RecordKeeper &records, raw_ostream &os) {
os << "/* Autogenerated by mlir-tblgen; don't manually edit. */"; os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
os << llvm::formatv(passGroupRegistrationCode, groupName); os << formatv(passGroupRegistrationCode, groupName);
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
Pass pass(def); Pass pass(def);
@ -103,10 +105,9 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
if (StringRef constructor = pass.getConstructor(); !constructor.empty()) if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str(); constructorCall = constructor.str();
else else
constructorCall = constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
os << llvm::formatv(passCreateDef, groupName, defName, constructorCall); os << formatv(passCreateDef, groupName, defName, constructorCall);
} }
return false; return false;
} }

View File

@ -18,6 +18,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::RecordKeeper;
/// Emit the documentation for the given pass. /// Emit the documentation for the given pass.
static void emitDoc(const Pass &pass, raw_ostream &os) { static void emitDoc(const Pass &pass, raw_ostream &os) {
@ -56,7 +57,7 @@ static void emitDoc(const Pass &pass, raw_ostream &os) {
} }
} }
static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { static void emitDocs(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n"; os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
auto passDefs = recordKeeper.getAllDerivedDefinitions("PassBase"); auto passDefs = recordKeeper.getAllDerivedDefinitions("PassBase");
@ -74,7 +75,7 @@ static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
static mlir::GenRegistration static mlir::GenRegistration
genRegister("gen-pass-doc", "Generate pass documentation", genRegister("gen-pass-doc", "Generate pass documentation",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
emitDocs(records, os); emitDocs(records, os);
return false; return false;
}); });

View File

@ -21,6 +21,8 @@
using namespace mlir; using namespace mlir;
using namespace mlir::tblgen; using namespace mlir::tblgen;
using llvm::formatv;
using llvm::RecordKeeper;
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls"); static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
static llvm::cl::opt<std::string> static llvm::cl::opt<std::string>
@ -28,7 +30,7 @@ static llvm::cl::opt<std::string>
llvm::cl::cat(passGenCat)); llvm::cl::cat(passGenCat));
/// Extract the list of passes from the TableGen records. /// Extract the list of passes from the TableGen records.
static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) { static std::vector<Pass> getPasses(const RecordKeeper &recordKeeper) {
std::vector<Pass> passes; std::vector<Pass> passes;
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase")) for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
@ -91,7 +93,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
if (options.empty()) if (options.empty())
return; return;
os << llvm::formatv("struct {0}Options {{\n", passName); os << formatv("struct {0}Options {{\n", passName);
for (const PassOption &opt : options) { for (const PassOption &opt : options) {
std::string type = opt.getType().str(); std::string type = opt.getType().str();
@ -99,7 +101,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
if (opt.isListOption()) if (opt.isListOption())
type = "::llvm::SmallVector<" + type + ">"; type = "::llvm::SmallVector<" + type + ">";
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName()); os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
os << " = " << defaultVal; os << " = " << defaultVal;
@ -128,7 +130,7 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
// Declaration of the constructor with options. // Declaration of the constructor with options.
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(" os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
"{0}Options options);\n", "{0}Options options);\n",
passName); passName);
} }
@ -147,14 +149,13 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
if (StringRef constructor = pass.getConstructor(); !constructor.empty()) if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str(); constructorCall = constructor.str();
else else
constructorCall = constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), os << formatv(passRegistrationCode, pass.getDef()->getName(),
constructorCall); constructorCall);
} }
os << llvm::formatv(passGroupRegistrationCode, groupName); os << formatv(passGroupRegistrationCode, groupName);
for (const Pass &pass : passes) for (const Pass &pass : passes)
os << " register" << pass.getDef()->getName() << "();\n"; os << " register" << pass.getDef()->getName() << "();\n";
@ -270,9 +271,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
os.indent(2) << "::mlir::Pass::" os.indent(2) << "::mlir::Pass::"
<< (opt.isListOption() ? "ListOption" : "Option"); << (opt.isListOption() ? "ListOption" : "Option");
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))", os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
opt.getType(), opt.getCppVariableName(), opt.getType(), opt.getCppVariableName(), opt.getArgument(),
opt.getArgument(), opt.getDescription()); opt.getDescription());
if (std::optional<StringRef> defaultVal = opt.getDefaultValue()) if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
os << ", ::llvm::cl::init(" << defaultVal << ")"; os << ", ::llvm::cl::init(" << defaultVal << ")";
if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags()) if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
@ -284,9 +285,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
/// Emit the declarations for each of the pass statistics. /// Emit the declarations for each of the pass statistics.
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) { static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
for (const PassStatistic &stat : pass.getStatistics()) { for (const PassStatistic &stat : pass.getStatistics()) {
os << llvm::formatv( os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n", stat.getCppVariableName(), stat.getName(),
stat.getCppVariableName(), stat.getName(), stat.getDescription()); stat.getDescription());
} }
} }
@ -300,11 +301,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
os << "#ifdef " << enableVarName << "\n"; os << "#ifdef " << enableVarName << "\n";
if (emitDefaultConstructors) { if (emitDefaultConstructors) {
os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName); os << formatv(friendDefaultConstructorDeclTemplate, passName);
if (emitDefaultConstructorWithOptions) if (emitDefaultConstructorWithOptions)
os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate, os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName);
passName);
} }
std::string dependentDialectRegistrations; std::string dependentDialectRegistrations;
@ -313,23 +313,22 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
llvm::interleave( llvm::interleave(
pass.getDependentDialects(), dialectsOs, pass.getDependentDialects(), dialectsOs,
[&](StringRef dependentDialect) { [&](StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate, dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
dependentDialect);
}, },
"\n "); "\n ");
} }
os << "namespace impl {\n"; os << "namespace impl {\n";
os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(), os << formatv(baseClassBegin, passName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(), pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations); dependentDialectRegistrations);
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) { if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
os.indent(2) << llvm::formatv( os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n",
"{0}Base({0}Options options) : {0}Base() {{\n", passName); passName);
for (const PassOption &opt : pass.getOptions()) for (const PassOption &opt : pass.getOptions())
os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n", os.indent(4) << formatv("{0} = std::move(options.{0});\n",
opt.getCppVariableName()); opt.getCppVariableName());
os.indent(2) << "}\n"; os.indent(2) << "}\n";
@ -344,21 +343,20 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
os << "private:\n"; os << "private:\n";
if (emitDefaultConstructors) { if (emitDefaultConstructors) {
os << llvm::formatv(friendDefaultConstructorDefTemplate, passName); os << formatv(friendDefaultConstructorDefTemplate, passName);
if (!pass.getOptions().empty()) if (!pass.getOptions().empty())
os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate, os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName);
passName);
} }
os << "};\n"; os << "};\n";
os << "} // namespace impl\n"; os << "} // namespace impl\n";
if (emitDefaultConstructors) { if (emitDefaultConstructors) {
os << llvm::formatv(defaultConstructorDefTemplate, passName); os << formatv(defaultConstructorDefTemplate, passName);
if (emitDefaultConstructorWithOptions) if (emitDefaultConstructorWithOptions)
os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName); os << formatv(defaultConstructorWithOptionsDefTemplate, passName);
} }
os << "#undef " << enableVarName << "\n"; os << "#undef " << enableVarName << "\n";
@ -367,7 +365,7 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
static void emitPass(const Pass &pass, raw_ostream &os) { static void emitPass(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName(); StringRef passName = pass.getDef()->getName();
os << llvm::formatv(passHeader, passName); os << formatv(passHeader, passName);
emitPassDecls(pass, os); emitPassDecls(pass, os);
emitPassDefs(pass, os); emitPassDefs(pass, os);
@ -436,12 +434,11 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
llvm::interleave( llvm::interleave(
pass.getDependentDialects(), dialectsOs, pass.getDependentDialects(), dialectsOs,
[&](StringRef dependentDialect) { [&](StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate, dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
dependentDialect);
}, },
"\n "); "\n ");
} }
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(), os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(), pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations); dependentDialectRegistrations);
emitPassOptionDecls(pass, os); emitPassOptionDecls(pass, os);
@ -449,8 +446,7 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
os << "};\n"; os << "};\n";
} }
static void emitPasses(const llvm::RecordKeeper &recordKeeper, static void emitPasses(const RecordKeeper &recordKeeper, raw_ostream &os) {
raw_ostream &os) {
std::vector<Pass> passes = getPasses(recordKeeper); std::vector<Pass> passes = getPasses(recordKeeper);
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
@ -479,7 +475,7 @@ static void emitPasses(const llvm::RecordKeeper &recordKeeper,
static mlir::GenRegistration static mlir::GenRegistration
genPassDecls("gen-pass-decls", "Generate pass declarations", genPassDecls("gen-pass-decls", "Generate pass declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) { [](const RecordKeeper &records, raw_ostream &os) {
emitPasses(records, os); emitPasses(records, os);
return false; return false;
}); });

View File

@ -98,15 +98,15 @@ public:
StringRef getMergeInstance() const; StringRef getMergeInstance() const;
// Returns the underlying LLVM TableGen Record. // Returns the underlying LLVM TableGen Record.
const llvm::Record *getDef() const { return def; } const Record *getDef() const { return def; }
private: private:
// The TableGen definition of this availability. // The TableGen definition of this availability.
const llvm::Record *def; const Record *def;
}; };
} // namespace } // namespace
Availability::Availability(const llvm::Record *def) : def(def) { Availability::Availability(const Record *def) : def(def) {
assert(def->isSubClassOf("Availability") && assert(def->isSubClassOf("Availability") &&
"must be subclass of TableGen 'Availability' class"); "must be subclass of TableGen 'Availability' class");
} }