[NFC][MLIR][TableGen] Eliminate llvm:: for commonly used types (#110841)

Eliminate `llvm::` namespace qualifier for commonly used types in MLIR
TableGen backends to reduce code clutter.
This commit is contained in:
Rahul Joshi 2024-10-02 13:23:44 -07:00 committed by GitHub
parent 906ffc4b4a
commit bccd37f69f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 450 additions and 475 deletions

View File

@ -22,6 +22,8 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::Record;
using llvm::RecordKeeper;
//===----------------------------------------------------------------------===//
// Utility Functions
@ -30,14 +32,14 @@ using namespace mlir::tblgen;
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
/// specified and can only find one dialect's defs, use that.
static void collectAllDefs(StringRef selectedDialect,
ArrayRef<const llvm::Record *> records,
ArrayRef<const Record *> records,
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
// Nothing to do if no defs were found.
if (records.empty())
return;
auto defs = llvm::map_range(
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
if (selectedDialect.empty()) {
// If a dialect was not specified, ensure that all found defs belong to the
// same dialect.
@ -690,15 +692,14 @@ public:
bool emitDefs(StringRef selectedDialect);
protected:
DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
StringRef defType, StringRef valueType, bool isAttrGenerator)
: defRecords(defs), os(os), defType(defType), valueType(valueType),
isAttrGenerator(isAttrGenerator) {
// Sort by occurrence in file.
llvm::sort(defRecords,
[](const llvm::Record *lhs, const llvm::Record *rhs) {
return lhs->getID() < rhs->getID();
});
llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID();
});
}
/// Emit the list of def type names.
@ -707,7 +708,7 @@ protected:
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
/// The set of def records to emit.
std::vector<const llvm::Record *> defRecords;
std::vector<const Record *> defRecords;
/// The attribute or type class to emit.
/// The stream to emit to.
raw_ostream &os;
@ -722,13 +723,13 @@ protected:
/// A specialized generator for AttrDefs.
struct AttrDefGenerator : public DefGenerator {
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
};
/// A specialized generator for TypeDefs.
struct TypeDefGenerator : public DefGenerator {
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
"Type", "Type", /*isAttrGenerator=*/false) {}
};
@ -1030,9 +1031,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
getAllTypeConstraints(const llvm::RecordKeeper &records) {
getAllTypeConstraints(const RecordKeeper &records) {
std::vector<Constraint> result;
for (const llvm::Record *def :
for (const Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
// Ignore constraints defined outside of the top-level file.
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
@ -1047,7 +1048,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
return result;
}
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
static void emitTypeConstraintDecls(const RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
@ -1057,7 +1058,7 @@ bool {0}(::mlir::Type type);
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
}
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
static void emitTypeConstraintDefs(const RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
@ -1088,13 +1089,13 @@ static llvm::cl::opt<std::string>
static mlir::GenRegistration
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDefs(attrDialect);
});
static mlir::GenRegistration
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
AttrDefGenerator generator(records, os);
return generator.emitDecls(attrDialect);
});
@ -1110,13 +1111,13 @@ static llvm::cl::opt<std::string>
static mlir::GenRegistration
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDefs(typeDialect);
});
static mlir::GenRegistration
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});
@ -1124,14 +1125,14 @@ static mlir::GenRegistration
static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os);
return false;
});

View File

@ -18,11 +18,10 @@
using namespace llvm;
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
static llvm::cl::opt<std::string>
selectedBcDialect("bytecode-dialect",
llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
static cl::opt<std::string>
selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
cl::cat(dialectGenCat), cl::CommaSeparated);
namespace {
@ -306,7 +305,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
auto funScope = os.scope("{\n", "}\n\n");
// Check that predicates specified if multiple bytecode instances.
for (const llvm::Record *rec : make_second_range(vec)) {
for (const Record *rec : make_second_range(vec)) {
StringRef pred = rec->getValueAsString("printerPredicate");
if (vec.size() > 1 && pred.empty()) {
for (auto [index, rec] : vec) {

View File

@ -30,6 +30,8 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::Record;
using llvm::RecordKeeper;
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
llvm::cl::opt<std::string>
@ -39,8 +41,8 @@ llvm::cl::opt<std::string>
/// Utility iterator used for filtering records for a specific dialect.
namespace {
using DialectFilterIterator =
llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
std::function<bool(const llvm::Record *)>>;
llvm::filter_iterator<ArrayRef<Record *>::iterator,
std::function<bool(const Record *)>>;
} // namespace
static void populateDiscardableAttributes(
@ -62,8 +64,8 @@ static void populateDiscardableAttributes(
/// the given dialect.
template <typename T>
static iterator_range<DialectFilterIterator>
filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
auto filterFn = [&](const llvm::Record *record) {
filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
auto filterFn = [&](const Record *record) {
return T(record).getDialect() == dialect;
};
return {DialectFilterIterator(records.begin(), records.end(), filterFn),
@ -295,7 +297,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
<< "::" << dialect.getCppClassName() << ")\n";
}
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
static bool emitDialectDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("Dialect Declarations", os, recordKeeper);
@ -340,8 +342,7 @@ static const char *const dialectDestructorStr = R"(
)";
static void emitDialectDef(Dialect &dialect,
const llvm::RecordKeeper &recordKeeper,
static void emitDialectDef(Dialect &dialect, const RecordKeeper &recordKeeper,
raw_ostream &os) {
std::string cppClassName = dialect.getCppClassName();
@ -389,8 +390,7 @@ static void emitDialectDef(Dialect &dialect,
os << llvm::formatv(dialectDestructorStr, cppClassName);
}
static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
static bool emitDialectDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
emitSourceFileHeader("Dialect Definitions", os, recordKeeper);
auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
@ -411,12 +411,12 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
static mlir::GenRegistration
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
return emitDialectDecls(records, os);
});
static mlir::GenRegistration
genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
return emitDialectDefs(records, os);
});

View File

@ -21,6 +21,9 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
/// File header and includes.
constexpr const char *fileHeader = R"Py(
@ -42,44 +45,42 @@ static std::string makePythonEnumCaseName(StringRef name) {
/// Emits the Python class for the given enum.
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
if (!enumAttr.getSummary().empty())
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
os << "\n";
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
os << llvm::formatv(
" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
: "auto()");
os << formatv(" {0} = {1}\n",
makePythonEnumCaseName(enumCase.getSymbol()),
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
: "auto()");
}
os << "\n";
if (enumAttr.isBitEnum()) {
os << llvm::formatv(" def __iter__(self):\n"
" return iter([case for case in type(self) if "
"(self & case) is case])\n");
os << llvm::formatv(" def __len__(self):\n"
" return bin(self).count(\"1\")\n");
os << formatv(" def __iter__(self):\n"
" return iter([case for case in type(self) if "
"(self & case) is case])\n");
os << formatv(" def __len__(self):\n"
" return bin(self).count(\"1\")\n");
os << "\n";
}
os << llvm::formatv(" def __str__(self):\n");
os << formatv(" def __str__(self):\n");
if (enumAttr.isBitEnum())
os << llvm::formatv(" if len(self) > 1:\n"
" return \"{0}\".join(map(str, self))\n",
enumAttr.getDef().getValueAsString("separator"));
os << formatv(" if len(self) > 1:\n"
" return \"{0}\".join(map(str, self))\n",
enumAttr.getDef().getValueAsString("separator"));
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
os << llvm::formatv(" if self is {0}.{1}:\n",
enumAttr.getEnumClassName(),
makePythonEnumCaseName(enumCase.getSymbol()));
os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr());
os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
makePythonEnumCaseName(enumCase.getSymbol()));
os << formatv(" return \"{0}\"\n", enumCase.getStr());
}
os << llvm::formatv(
" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
enumAttr.getEnumClassName());
os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
enumAttr.getEnumClassName());
os << "\n";
}
@ -105,15 +106,13 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
return true;
}
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
enumAttr.getAttrDefName());
os << llvm::formatv("def _{0}(x, context):\n",
enumAttr.getAttrDefName().lower());
os << llvm::formatv(
" return "
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
"context=context), int(x))\n\n",
bitwidth);
os << formatv("@register_attribute_builder(\"{0}\")\n",
enumAttr.getAttrDefName());
os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
os << formatv(" return "
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
"context=context), int(x))\n\n",
bitwidth);
return false;
}
@ -123,26 +122,25 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
StringRef formatString,
raw_ostream &os) {
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
os << llvm::formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
os << formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
formatString);
return false;
}
/// Emits Python bindings for all enums in the record keeper. Returns
/// `false` on success, `true` on failure.
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
static bool emitPythonEnums(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << fileHeader;
for (const llvm::Record *it :
for (const Record *it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
EnumAttr enumAttr(*it);
emitEnumClass(enumAttr, os);
emitAttributeBuilder(enumAttr, os);
}
for (const llvm::Record *it :
for (const Record *it :
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
if (!attr.getMnemonic()) {
@ -156,11 +154,11 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder(
attr.getName(),
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
} else if (assemblyFormat == "$value") {
emitDialectEnumAttributeBuilder(
attr.getName(),
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
} else {
llvm::errs()
<< "unsupported assembly format for python enum bindings generation";

View File

@ -26,10 +26,9 @@
using llvm::formatv;
using llvm::isDigit;
using llvm::PrintFatalError;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
using namespace mlir;
using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::EnumAttrCase;
@ -139,7 +138,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
// is not a power of two (i.e. not a single bit case) and not a known case.
} else if (enumAttr.isBitEnum()) {
// Process the known multi-bit cases that use valid keywords.
llvm::SmallVector<EnumAttrCase *> validMultiBitCases;
SmallVector<EnumAttrCase *> validMultiBitCases;
for (auto [index, caseVal] : llvm::enumerate(cases)) {
uint64_t value = caseVal.getValue();
if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
@ -476,7 +475,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
const llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
const Record *baseAttrDef = enumAttr.getBaseAttrClass();
Attribute baseAttr(baseAttrDef);
// Emit classof method
@ -565,7 +564,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
llvm::SmallVector<StringRef, 2> namespaces;
SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
@ -656,7 +655,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef cppNamespace = enumAttr.getCppNamespace();
llvm::SmallVector<StringRef, 2> namespaces;
SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)

View File

@ -13,6 +13,7 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::SourceMgr;
//===----------------------------------------------------------------------===//
// FormatToken
@ -26,14 +27,14 @@ SMLoc FormatToken::getLoc() const {
// FormatLexer
//===----------------------------------------------------------------------===//
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc)
FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc)
: mgr(mgr), loc(loc),
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
curPtr(curBuffer.begin()) {}
FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
"in custom assembly format for this operation");
return formToken(FormatToken::error, loc.getPointer());
}
@ -44,10 +45,10 @@ FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
const Twine &note) {
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
"in custom assembly format for this operation");
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
mgr.PrintMessage(loc, SourceMgr::DK_Note, note);
return formToken(FormatToken::error, loc.getPointer());
}

View File

@ -411,8 +411,7 @@ public:
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API enumerant
static void emitOneEnumToConversion(const llvm::Record *record,
raw_ostream &os) {
static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
@ -441,8 +440,7 @@ static void emitOneEnumToConversion(const llvm::Record *record,
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API C-style enumerant
static void emitOneCEnumToConversion(const llvm::Record *record,
raw_ostream &os) {
static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
@ -472,8 +470,7 @@ static void emitOneCEnumToConversion(const llvm::Record *record,
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
// LLVM dialect enum attribute (Enum).
static void emitOneEnumFromConversion(const llvm::Record *record,
raw_ostream &os) {
static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
@ -508,8 +505,7 @@ static void emitOneEnumFromConversion(const llvm::Record *record,
// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
// containing switch-based logic to convert from the LLVM API C-style enumerant
// to MLIR LLVM dialect enum attribute (Enum).
static void emitOneCEnumFromConversion(const llvm::Record *record,
raw_ostream &os) {
static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();

View File

@ -24,6 +24,11 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using llvm::Record;
using llvm::RecordKeeper;
using llvm::Regex;
using namespace mlir;
static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options");
static llvm::cl::opt<std::string>
@ -54,14 +59,14 @@ static llvm::cl::opt<std::string> aliasAnalysisRegexp(
using IndicesTy = llvm::SmallBitVector;
/// Return a CodeGen value type entry from a type record.
static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) {
static llvm::MVT::SimpleValueType getValueType(const Record *rec) {
return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
"Value");
}
/// Return the indices of the definitions in a list of definitions that
/// represent overloadable types
static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record,
static IndicesTy getOverloadableTypeIdxs(const Record &record,
const char *listName) {
auto results = record.getValueAsListOfDefs(listName);
IndicesTy overloadedOps(results.size());
@ -87,13 +92,13 @@ namespace {
/// the fields of the record.
class LLVMIntrinsic {
public:
LLVMIntrinsic(const llvm::Record &record) : record(record) {}
LLVMIntrinsic(const Record &record) : record(record) {}
/// Get the name of the operation to be used in MLIR. Uses the appropriate
/// field if not empty, constructs a name by replacing underscores with dots
/// in the record name otherwise.
std::string getOperationName() const {
llvm::StringRef name = record.getValueAsString(fieldName);
StringRef name = record.getValueAsString(fieldName);
if (!name.empty())
return name.str();
@ -101,8 +106,8 @@ public:
assert(name.starts_with("int_") &&
"LLVM intrinsic names are expected to start with 'int_'");
name = name.drop_front(4);
llvm::SmallVector<llvm::StringRef, 8> chunks;
llvm::StringRef targetPrefix = record.getValueAsString("TargetPrefix");
SmallVector<StringRef, 8> chunks;
StringRef targetPrefix = record.getValueAsString("TargetPrefix");
name.split(chunks, '_');
auto *chunksBegin = chunks.begin();
// Remove the target prefix from target specific intrinsics.
@ -119,8 +124,8 @@ public:
}
/// Get the name of the record without the "intrinsic" prefix.
llvm::StringRef getProperRecordName() const {
llvm::StringRef name = record.getName();
StringRef getProperRecordName() const {
StringRef name = record.getName();
assert(name.starts_with("int_") &&
"LLVM intrinsic names are expected to start with 'int_'");
return name.drop_front(4);
@ -129,10 +134,9 @@ public:
/// Get the number of operands.
unsigned getNumOperands() const {
auto operands = record.getValueAsListOfDefs(fieldOperands);
assert(llvm::all_of(operands,
[](const llvm::Record *r) {
return r->isSubClassOf("LLVMType");
}) &&
assert(llvm::all_of(
operands,
[](const Record *r) { return r->isSubClassOf("LLVMType"); }) &&
"expected operands to be of LLVM type");
return operands.size();
}
@ -142,7 +146,7 @@ public:
/// structure type.
unsigned getNumResults() const {
auto results = record.getValueAsListOfDefs(fieldResults);
for (const llvm::Record *r : results) {
for (const Record *r : results) {
(void)r;
assert(r->isSubClassOf("LLVMType") &&
"expected operands to be of LLVM type");
@ -155,7 +159,7 @@ public:
bool hasSideEffects() const {
return llvm::none_of(
record.getValueAsListOfDefs(fieldTraits),
[](const llvm::Record *r) { return r->getName() == "IntrNoMem"; });
[](const Record *r) { return r->getName() == "IntrNoMem"; });
}
/// Return true if the intrinsic is commutative, i.e. has the respective
@ -163,7 +167,7 @@ public:
bool isCommutative() const {
return llvm::any_of(
record.getValueAsListOfDefs(fieldTraits),
[](const llvm::Record *r) { return r->getName() == "Commutative"; });
[](const Record *r) { return r->getName() == "Commutative"; });
}
IndicesTy getOverloadableOperandsIdxs() const {
@ -181,7 +185,7 @@ private:
const char *fieldResults = "RetTypes";
const char *fieldTraits = "IntrProperties";
const llvm::Record &record;
const Record &record;
};
} // namespace
@ -195,27 +199,26 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
/// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic.
/// Returns true on error, false on success.
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
LLVMIntrinsic intr(record);
llvm::Regex accessGroupMatcher(accessGroupRegexp);
Regex accessGroupMatcher(accessGroupRegexp);
bool requiresAccessGroup =
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
llvm::Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() &&
aliasAnalysisMatcher.match(record.getName());
// Prepare strings for traits, if any.
llvm::SmallVector<llvm::StringRef, 2> traits;
SmallVector<StringRef, 2> traits;
if (intr.isCommutative())
traits.push_back("Commutative");
if (!intr.hasSideEffects())
traits.push_back("NoMemoryEffect");
// Prepare strings for operands.
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
"LLVM_Type");
SmallVector<StringRef, 8> operands(intr.getNumOperands(), "LLVM_Type");
if (requiresAccessGroup)
operands.push_back(
"OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups");
@ -247,14 +250,13 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
/// Traverses the list of TableGen definitions derived from the "Intrinsic"
/// class and generates MLIR ODS definitions for those intrinsics that have
/// the name matching the filter.
static bool emitIntrinsics(const llvm::RecordKeeper &records,
llvm::raw_ostream &os) {
static bool emitIntrinsics(const RecordKeeper &records, llvm::raw_ostream &os) {
llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records);
os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n";
os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n";
auto defs = records.getAllDerivedDefinitions("Intrinsic");
for (const llvm::Record *r : defs) {
for (const Record *r : defs) {
if (!nameFilter.empty() && !r->getName().contains(nameFilter))
continue;
if (emitIntrinsic(*r, os))

View File

@ -34,30 +34,30 @@
#include <set>
#include <string>
//===----------------------------------------------------------------------===//
// Commandline Options
//===----------------------------------------------------------------------===//
static llvm::cl::OptionCategory
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
llvm::cl::opt<std::string>
stripPrefix("strip-prefix",
llvm::cl::desc("Strip prefix of the fully qualified names"),
llvm::cl::init("::mlir::"), llvm::cl::cat(docCat));
llvm::cl::opt<bool> allowHugoSpecificFeatures(
"allow-hugo-specific-features",
llvm::cl::desc("Allows using features specific to Hugo"),
llvm::cl::init(false), llvm::cl::cat(docCat));
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
using mlir::tblgen::Operator;
//===----------------------------------------------------------------------===//
// Commandline Options
//===----------------------------------------------------------------------===//
static cl::OptionCategory
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
cl::opt<std::string>
stripPrefix("strip-prefix",
cl::desc("Strip prefix of the fully qualified names"),
cl::init("::mlir::"), cl::cat(docCat));
cl::opt<bool> allowHugoSpecificFeatures(
"allow-hugo-specific-features",
cl::desc("Allows using features specific to Hugo"), cl::init(false),
cl::cat(docCat));
void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) {
if (!summary.empty()) {
llvm::StringRef trimmed = summary.trim();
StringRef trimmed = summary.trim();
char first = std::toupper(trimmed.front());
llvm::StringRef rest = trimmed.drop_front();
StringRef rest = trimmed.drop_front();
os << "\n_" << first << rest << "_\n\n";
}
}
@ -152,10 +152,10 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
effectName.consume_front("::");
effectName.consume_front("mlir::");
std::string effectStr;
llvm::raw_string_ostream os(effectStr);
raw_string_ostream os(effectStr);
os << effectName << "{";
auto list = trait.getDef().getValueAsListOfDefs("effects");
llvm::interleaveComma(list, os, [&](const Record *rec) {
interleaveComma(list, os, [&](const Record *rec) {
StringRef effect = rec->getValueAsString("effect");
effect.consume_front("::");
effect.consume_front("mlir::");
@ -163,7 +163,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
});
os << "}";
effects.insert(backticks(effectStr));
name.append(llvm::formatv(" ({0})", traitName).str());
name.append(formatv(" ({0})", traitName).str());
}
interfaces.insert(backticks(name));
continue;
@ -172,15 +172,15 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
traits.insert(backticks(name));
}
if (!traits.empty()) {
llvm::interleaveComma(traits, os << "\nTraits: ");
interleaveComma(traits, os << "\nTraits: ");
os << "\n";
}
if (!interfaces.empty()) {
llvm::interleaveComma(interfaces, os << "\nInterfaces: ");
interleaveComma(interfaces, os << "\nInterfaces: ");
os << "\n";
}
if (!effects.empty()) {
llvm::interleaveComma(effects, os << "\nEffects: ");
interleaveComma(effects, os << "\nEffects: ");
os << "\n";
}
}
@ -196,7 +196,7 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
std::string classNameStr = op.getQualCppClassName();
StringRef className = classNameStr;
(void)className.consume_front(stripPrefix);
os << llvm::formatv("### `{0}` ({1})\n", op.getOperationName(), className);
os << formatv("### `{0}` ({1})\n", op.getOperationName(), className);
// Emit the summary, syntax, and description if present.
if (op.hasSummary())
@ -287,7 +287,7 @@ static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
emitSourceLink(recordKeeper.getInputFilename(), os);
for (const llvm::Record *opDef : opDefs)
for (const Record *opDef : opDefs)
emitOpDoc(Operator(opDef), os);
}
@ -339,7 +339,7 @@ static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def,
}
static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
os << llvm::formatv("### {0}\n", def.getCppClassName());
os << formatv("### {0}\n", def.getCppClassName());
// Emit the summary if present.
if (def.hasSummary())
@ -376,7 +376,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName);
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const llvm::Record *def : defs)
for (const Record *def : defs)
emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os);
}
@ -385,7 +385,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
//===----------------------------------------------------------------------===//
static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
os << llvm::formatv("### {0}\n", def.getEnumClassName());
os << formatv("### {0}\n", def.getEnumClassName());
// Emit the summary if present.
if (!def.getSummary().empty())
@ -406,8 +406,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const llvm::Record *def :
recordKeeper.getAllDerivedDefinitions("EnumAttr"))
for (const Record *def : recordKeeper.getAllDerivedDefinitions("EnumAttr"))
emitEnumDoc(EnumAttr(def), os);
}
@ -431,7 +430,7 @@ struct OpDocGroup {
static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
raw_ostream &os) {
std::string str;
llvm::raw_string_ostream ss(str);
raw_string_ostream ss(str);
fn(ss);
for (StringRef x : llvm::split(str, "\n")) {
if (nest && x.starts_with("#"))
@ -507,7 +506,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
emitIfNotEmpty(dialect.getDescription(), os);
// Generate a TOC marker except if description already contains one.
llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline);
Regex r("^[[:space:]]*\\[TOC\\]$", Regex::RegexFlags::Newline);
if (!r.match(dialect.getDescription()))
os << "[TOC]\n\n";
@ -537,17 +536,15 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<TypeDef> dialectTypeDefs;
std::vector<EnumAttr> dialectEnums;
llvm::SmallDenseSet<const Record *> seen;
auto addIfNotSeen = [&](const llvm::Record *record, const auto &def,
auto &vec) {
SmallDenseSet<const Record *> seen;
auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) {
if (seen.insert(record).second) {
vec.push_back(def);
return true;
}
return false;
};
auto addIfInDialect = [&](const llvm::Record *record, const auto &def,
auto &vec) {
auto addIfInDialect = [&](const Record *record, const auto &def, auto &vec) {
return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
};

View File

@ -28,6 +28,9 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::StringMap;
//===----------------------------------------------------------------------===//
// VariableElement
@ -404,7 +407,7 @@ struct OperationFormat {
StringRef opCppClassName;
/// A map of buildable types to indices.
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes;
/// The index of the buildable type, if valid, for every operand and result.
std::vector<TypeResolution> operandTypes, resultTypes;
@ -891,8 +894,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
const NamedAttribute *var = attr->getVar();
body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
var->name);
body << formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
StringRef name = operand->getVar()->name;
@ -910,31 +912,31 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
<< name << "Operands(&" << name << "RawOperand, 1);";
}
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
" (void){0}OperandsLoc;\n",
name);
body << formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
" (void){0}OperandsLoc;\n",
name);
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
StringRef name = region->getVar()->name;
if (region->getVar()->isVariadic()) {
body << llvm::formatv(
body << formatv(
" ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
"{0}Regions;\n",
name);
} else {
body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
"std::make_unique<::mlir::Region>();\n",
name);
body << formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
"std::make_unique<::mlir::Region>();\n",
name);
}
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic()) {
body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
"{0}Successors;\n",
name);
body << formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
"{0}Successors;\n",
name);
} else {
body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
body << formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
}
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
@ -944,8 +946,8 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
else
body
<< llvm::formatv(" ::mlir::Type {0}RawType{{};\n", name)
<< llvm::formatv(
<< formatv(" ::mlir::Type {0}RawType{{};\n", name)
<< formatv(
" ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
name);
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
@ -969,27 +971,27 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
StringRef name = operand->getVar()->name;
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv("{0}OperandGroups", name);
body << formatv("{0}OperandGroups", name);
else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Operands", name);
body << formatv("{0}Operands", name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Operand", name);
body << formatv("{0}Operand", name);
else
body << formatv("{0}RawOperand", name);
} else if (auto *region = dyn_cast<RegionVariable>(param)) {
StringRef name = region->getVar()->name;
if (region->getVar()->isVariadic())
body << llvm::formatv("{0}Regions", name);
body << formatv("{0}Regions", name);
else
body << llvm::formatv("*{0}Region", name);
body << formatv("*{0}Region", name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
StringRef name = successor->getVar()->name;
if (successor->getVar()->isVariadic())
body << llvm::formatv("{0}Successors", name);
body << formatv("{0}Successors", name);
else
body << llvm::formatv("{0}Successor", name);
body << formatv("{0}Successor", name);
} else if (auto *dir = dyn_cast<RefDirective>(param)) {
genCustomParameterParser(dir->getArg(), body);
@ -998,11 +1000,11 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv("{0}TypeGroups", listName);
body << formatv("{0}TypeGroups", listName);
else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Types", listName);
body << formatv("{0}Types", listName);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Type", listName);
body << formatv("{0}Type", listName);
else
body << formatv("{0}RawType", listName);
@ -1013,8 +1015,8 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
body << tgfmt(string->getValue(), &ctx);
} else if (auto *property = dyn_cast<PropertyVariable>(param)) {
body << llvm::formatv("result.getOrAddProperties<Properties>().{0}",
property->getVar()->name);
body << formatv("result.getOrAddProperties<Properties>().{0}",
property->getVar()->name);
} else {
llvm_unreachable("unknown custom directive parameter");
}
@ -1037,24 +1039,24 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
body << " " << var->name
<< "OperandsLoc = parser.getCurrentLocation();\n";
if (var->isOptional()) {
body << llvm::formatv(
body << formatv(
" ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
"{0}Operand;\n",
var->name);
} else if (var->isVariadicOfVariadic()) {
body << llvm::formatv(" "
"::llvm::SmallVector<::llvm::SmallVector<::mlir::"
"OpAsmParser::UnresolvedOperand>> "
"{0}OperandGroups;\n",
var->name);
body << formatv(" "
"::llvm::SmallVector<::llvm::SmallVector<::mlir::"
"OpAsmParser::UnresolvedOperand>> "
"{0}OperandGroups;\n",
var->name);
}
} else if (auto *dir = dyn_cast<TypeDirective>(param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
body << formatv(" ::mlir::Type {0}Type;\n", listName);
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(
body << formatv(
" ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
"{0}TypeGroups;\n",
listName);
@ -1064,7 +1066,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
if (auto *operand = dyn_cast<OperandVariable>(input)) {
if (!operand->getVar()->isOptional())
continue;
body << llvm::formatv(
body << formatv(
" {0} {1}Operand = {1}Operands.empty() ? {0}() : "
"{1}Operands[0];\n",
"::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
@ -1074,9 +1076,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(type->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
"::mlir::Type() : {0}Types[0];\n",
listName);
body << formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
"::mlir::Type() : {0}Types[0];\n",
listName);
}
}
}
@ -1101,23 +1103,23 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
const NamedAttribute *var = attr->getVar();
if (var->attr.isOptional() || var->attr.hasDefaultValue())
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
body << formatv(" if ({0}Attr)\n ", var->name);
if (useProperties) {
body << formatv(
" result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
var->name, opCppClassName);
} else {
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
var->name);
body << formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
var->name);
}
} else if (auto *operand = dyn_cast<OperandVariable>(param)) {
const NamedTypeConstraint *var = operand->getVar();
if (var->isOptional()) {
body << llvm::formatv(" if ({0}Operand.has_value())\n"
" {0}Operands.push_back(*{0}Operand);\n",
var->name);
body << formatv(" if ({0}Operand.has_value())\n"
" {0}Operands.push_back(*{0}Operand);\n",
var->name);
} else if (var->isVariadicOfVariadic()) {
body << llvm::formatv(
body << formatv(
" for (const auto &subRange : {0}OperandGroups) {{\n"
" {0}Operands.append(subRange.begin(), subRange.end());\n"
" {0}OperandGroupSizes.push_back(subRange.size());\n"
@ -1128,11 +1130,11 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" if ({0}Type)\n"
" {0}Types.push_back({0}Type);\n",
listName);
body << formatv(" if ({0}Type)\n"
" {0}Types.push_back({0}Type);\n",
listName);
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(
body << formatv(
" for (const auto &subRange : {0}TypeGroups)\n"
" {0}Types.append(subRange.begin(), subRange.end());\n",
listName);
@ -1460,9 +1462,9 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
} else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) {
genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false);
body << llvm::formatv("if ({0}PropParseResult.has_value() && "
"succeeded(*{0}PropParseResult)) ",
propVar->getVar()->name)
body << formatv("if ({0}PropParseResult.has_value() && "
"succeeded(*{0}PropParseResult)) ",
propVar->getVar()->name)
<< " {\n";
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
body << " if (::mlir::succeeded(parser.parseOptional";
@ -1477,13 +1479,12 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
genElementParser(regionVar, body, attrTypeCtx);
body << " if (!" << region->name << "Regions.empty()) {\n";
} else {
body << llvm::formatv(optionalRegionParserCode, region->name);
body << formatv(optionalRegionParserCode, region->name);
body << " if (!" << region->name << "Region->empty()) {\n ";
if (hasImplicitTermTrait)
body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
body << formatv(regionEnsureTerminatorParserCode, region->name);
else if (hasSingleBlockTrait)
body << llvm::formatv(regionEnsureSingleBlockParserCode,
region->name);
body << formatv(regionEnsureSingleBlockParserCode, region->name);
}
} else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
@ -1575,26 +1576,26 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv(variadicOfVariadicOperandParserCode, name);
body << formatv(variadicOfVariadicOperandParserCode, name);
else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicOperandParserCode, name);
body << formatv(variadicOperandParserCode, name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalOperandParserCode, name);
body << formatv(optionalOperandParserCode, name);
else
body << formatv(operandParserCode, name);
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
bool isVariadic = region->getVar()->isVariadic();
body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
region->getVar()->name);
body << formatv(isVariadic ? regionListParserCode : regionParserCode,
region->getVar()->name);
if (hasImplicitTermTrait)
body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
: regionEnsureTerminatorParserCode,
region->getVar()->name);
body << formatv(isVariadic ? regionListEnsureTerminatorParserCode
: regionEnsureTerminatorParserCode,
region->getVar()->name);
else if (hasSingleBlockTrait)
body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
: regionEnsureSingleBlockParserCode,
region->getVar()->name);
body << formatv(isVariadic ? regionListEnsureSingleBlockParserCode
: regionEnsureSingleBlockParserCode,
region->getVar()->name);
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
bool isVariadic = successor->getVar()->isVariadic();
@ -1631,24 +1632,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
<< " return ::mlir::failure();\n";
} else if (isa<RegionsDirective>(element)) {
body << llvm::formatv(regionListParserCode, "full");
body << formatv(regionListParserCode, "full");
if (hasImplicitTermTrait)
body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
body << formatv(regionListEnsureTerminatorParserCode, "full");
else if (hasSingleBlockTrait)
body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
body << formatv(regionListEnsureSingleBlockParserCode, "full");
} else if (isa<SuccessorsDirective>(element)) {
body << llvm::formatv(successorListParserCode, "full");
body << formatv(successorListParserCode, "full");
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
body << formatv(variadicOfVariadicTypeParserCode, listName);
} else if (lengthKind == ArgumentLengthKind::Variadic) {
body << llvm::formatv(variadicTypeParserCode, listName);
body << formatv(variadicTypeParserCode, listName);
} else if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(optionalTypeParserCode, listName);
body << formatv(optionalTypeParserCode, listName);
} else {
const char *parserCode =
dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
@ -1903,14 +1904,14 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
if (!operand.isVariadicOfVariadic())
continue;
if (op.getDialect().usePropertiesForAttributes()) {
body << llvm::formatv(
body << formatv(
" result.getOrAddProperties<{0}::Properties>().{1} = "
"parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
op.getCppClassName(),
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
operand.name);
} else {
body << llvm::formatv(
body << formatv(
" result.addAttribute(\"{0}\", "
"parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
"\n",
@ -2160,7 +2161,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
if (var->isVariadic())
body << name << "().getTypes()";
else if (var->isOptional())
body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
body << formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
else
body << name << "().getType()";
@ -2195,8 +2196,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
static void genRegionPrinter(const Twine &regionName, MethodBody &body,
bool hasImplicitTermTrait) {
if (hasImplicitTermTrait)
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
regionName);
body << formatv(regionSingleBlockImplicitTerminatorPrinterCode, regionName);
else
body << " _odsPrinter.printRegion(" << regionName << ");\n";
}
@ -2220,12 +2220,12 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
auto *operand = dyn_cast<OperandVariable>(arg);
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
if (var->isVariadicOfVariadic())
return body << llvm::formatv("{0}().join().getTypes()",
op.getGetterName(var->name));
return body << formatv("{0}().join().getTypes()",
op.getGetterName(var->name));
if (var->isVariadic())
return body << op.getGetterName(var->name) << "().getTypes()";
if (var->isOptional())
return body << llvm::formatv(
return body << formatv(
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
"::llvm::ArrayRef<::mlir::Type>())",
op.getGetterName(var->name));
@ -2242,10 +2242,10 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
body << llvm::formatv(enumAttrBeginPrinterCode,
(var->attr.isOptional() ? "*" : "") +
op.getGetterName(var->name),
enumAttr.getSymbolToStringFnName());
body << formatv(enumAttrBeginPrinterCode,
(var->attr.isOptional() ? "*" : "") +
op.getGetterName(var->name),
enumAttr.getSymbolToStringFnName());
// Get a string containing all of the cases that can't be represented with a
// keyword.
@ -2276,9 +2276,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
if (nonKeywordCases.test(it.index()))
continue;
StringRef symbol = it.value().getSymbol();
body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
llvm::isDigit(symbol.front()) ? ("_" + symbol)
: symbol);
body << formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol);
}
body << " _odsPrinter << caseValueStr;\n"
" break;\n"
@ -2584,7 +2583,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
if (operand->getVar()->isVariadicOfVariadic()) {
body << llvm::formatv(
body << formatv(
" ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
"[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
"types << \")\"; });\n",
@ -2710,7 +2709,7 @@ private:
/// Verify the state of operation operands within the format.
LogicalResult
verifyOperands(SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
StringMap<TypeResolutionInstance> &variableTyResolver);
/// Verify the state of operation regions within the format.
LogicalResult verifyRegions(SMLoc loc);
@ -2718,7 +2717,7 @@ private:
/// Verify the state of operation results within the format.
LogicalResult
verifyResults(SMLoc loc,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
StringMap<TypeResolutionInstance> &variableTyResolver);
/// Verify the state of operation successors within the format.
LogicalResult verifySuccessors(SMLoc loc);
@ -2730,18 +2729,17 @@ private:
/// resolution.
void handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
StringMap<TypeResolutionInstance> &variableTyResolver);
/// Check for inferable type resolution given all operands, and or results,
/// have the same type. If 'includeResults' is true, the results also have the
/// same type as all of the operands.
void handleSameTypesConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults);
/// Check for inferable type resolution based on another operand, result, or
/// attribute.
void handleTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
const llvm::Record &def);
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
/// Returns an argument or attribute with the given name that has been seen
/// within the format.
@ -2794,9 +2792,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
"custom assembly format");
// Check for any type traits that we can use for inferring types.
llvm::StringMap<TypeResolutionInstance> variableTyResolver;
StringMap<TypeResolutionInstance> variableTyResolver;
for (const Trait &trait : op.getTraits()) {
const llvm::Record &def = trait.getDef();
const Record &def = trait.getDef();
if (def.isSubClassOf("AllTypesMatch")) {
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
variableTyResolver);
@ -2995,10 +2993,9 @@ OpFormatParser::verifyAttributeColonType(SMLoc loc,
return false;
// If we encounter `:`, the range is known to be invalid.
(void)emitError(
loc,
llvm::formatv("format ambiguity caused by `:` literal found after "
"attribute `{0}` which does not have a buildable type",
cast<AttributeVariable>(base)->getVar()->name));
loc, formatv("format ambiguity caused by `:` literal found after "
"attribute `{0}` which does not have a buildable type",
cast<AttributeVariable>(base)->getVar()->name));
return true;
};
return verifyAdjacentElements(isBase, isInvalid, elements);
@ -3018,9 +3015,9 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
return false;
(void)emitErrorAndNote(
loc,
llvm::formatv("format ambiguity caused by `attr-dict` directive "
"followed by region `{0}`",
region->getVar()->name),
formatv("format ambiguity caused by `attr-dict` directive "
"followed by region `{0}`",
region->getVar()->name),
"try using `attr-dict-with-keyword` instead");
return true;
};
@ -3028,7 +3025,7 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
}
LogicalResult OpFormatParser::verifyOperands(
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
// Check that all of the operands are within the format, and their types can
// be inferred.
auto &buildableTypes = fmt.buildableTypes;
@ -3093,7 +3090,7 @@ LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
}
LogicalResult OpFormatParser::verifyResults(
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
// If we format all of the types together, there is nothing to check.
if (fmt.allResultTypes)
return success();
@ -3197,7 +3194,7 @@ OpFormatParser::verifyOIListElements(SMLoc loc,
void OpFormatParser::handleAllTypesMatchConstraint(
ArrayRef<StringRef> values,
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
StringMap<TypeResolutionInstance> &variableTyResolver) {
for (unsigned i = 0, e = values.size(); i != e; ++i) {
// Check to see if this value matches a resolved operand or result type.
ConstArgument arg = findSeenArg(values[i]);
@ -3213,7 +3210,7 @@ void OpFormatParser::handleAllTypesMatchConstraint(
}
void OpFormatParser::handleSameTypesConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
StringMap<TypeResolutionInstance> &variableTyResolver,
bool includeResults) {
const NamedTypeConstraint *resolver = nullptr;
int resolvedIt = -1;
@ -3238,8 +3235,7 @@ void OpFormatParser::handleSameTypesConstraint(
}
void OpFormatParser::handleTypesMatchConstraint(
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
const llvm::Record &def) {
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
StringRef lhsName = def.getValueAsString("lhs");
StringRef rhsName = def.getValueAsString("rhs");
StringRef transformer = def.getValueAsString("transformer");

View File

@ -41,7 +41,7 @@ static std::string getOperationName(const Record &def) {
auto opName = def.getValueAsString("opName");
if (prefix.empty())
return std::string(opName);
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
return std::string(formatv("{0}.{1}", prefix, opName));
}
std::vector<const Record *>
@ -50,7 +50,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
if (!classDef)
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
std::vector<const Record *> defs;
for (const auto &def : recordKeeper.getDefs()) {
if (!def.second->isSubClassOf(classDef))
@ -70,7 +70,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
}
bool mlir::tblgen::isPythonReserved(StringRef str) {
static llvm::StringSet<> reserved({
static StringSet<> reserved({
"False", "None", "True", "and", "as", "assert", "async",
"await", "break", "class", "continue", "def", "del", "elif",
"else", "except", "finally", "for", "from", "global", "if",
@ -86,8 +86,8 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
}
void mlir::tblgen::shardOpDefinitions(
ArrayRef<const llvm::Record *> defs,
SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs) {
ArrayRef<const Record *> defs,
SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
assert(opShardCount > 0 && "expected a positive shard count");
if (opShardCount == 1) {
shardedDefs.push_back(defs);

View File

@ -23,6 +23,8 @@
#include "llvm/TableGen/TableGenBackend.h"
using namespace mlir;
using llvm::Record;
using llvm::RecordKeeper;
using mlir::tblgen::Interface;
using mlir::tblgen::InterfaceMethod;
using mlir::tblgen::OpInterface;
@ -61,14 +63,13 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
/// Get an array of all OpInterface definitions but exclude those subclassing
/// "DeclareOpInterfaceMethods".
static std::vector<const llvm::Record *>
getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
StringRef name) {
std::vector<const llvm::Record *> defs =
static std::vector<const Record *>
getAllInterfaceDefinitions(const RecordKeeper &recordKeeper, StringRef name) {
std::vector<const Record *> defs =
recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
std::string declareName = ("Declare" + name + "InterfaceMethods").str();
llvm::erase_if(defs, [&](const llvm::Record *def) {
llvm::erase_if(defs, [&](const Record *def) {
// Ignore any "declare methods" interfaces.
if (def->isSubClassOf(declareName))
return true;
@ -88,7 +89,7 @@ public:
bool emitInterfaceDocs();
protected:
InterfaceGenerator(std::vector<const llvm::Record *> &&defs, raw_ostream &os)
InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
: defs(std::move(defs)), os(os) {}
void emitConceptDecl(const Interface &interface);
@ -99,7 +100,7 @@ protected:
void emitInterfaceDecl(const Interface &interface);
/// The set of interface records to emit.
std::vector<const llvm::Record *> defs;
std::vector<const Record *> defs;
// The stream to emit to.
raw_ostream &os;
/// The C++ value type of the interface, e.g. Operation*.
@ -118,7 +119,7 @@ protected:
/// A specialized generator for attribute interfaces.
struct AttrInterfaceGenerator : public InterfaceGenerator {
AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
valueType = "::mlir::Attribute";
interfaceBaseType = "AttributeInterface";
@ -133,7 +134,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
};
/// A specialized generator for operation interfaces.
struct OpInterfaceGenerator : public InterfaceGenerator {
OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface";
@ -149,7 +150,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
};
/// A specialized generator for type interfaces.
struct TypeInterfaceGenerator : public InterfaceGenerator {
TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface";
@ -607,13 +608,13 @@ bool InterfaceGenerator::emitInterfaceDecls() {
llvm::emitSourceFileHeader("Interface Declarations", os);
// Sort according to ID, so defs are emitted in the order in which they appear
// in the Tablegen file.
std::vector<const llvm::Record *> sortedDefs(defs);
llvm::sort(sortedDefs, [](const llvm::Record *lhs, const llvm::Record *rhs) {
std::vector<const Record *> sortedDefs(defs);
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
return lhs->getID() < rhs->getID();
});
for (const llvm::Record *def : sortedDefs)
for (const Record *def : sortedDefs)
emitInterfaceDecl(Interface(def));
for (const llvm::Record *def : sortedDefs)
for (const Record *def : sortedDefs)
emitModelMethodsDef(Interface(def));
return false;
}
@ -622,8 +623,7 @@ bool InterfaceGenerator::emitInterfaceDecls() {
// GEN: Interface documentation
//===----------------------------------------------------------------------===//
static void emitInterfaceDoc(const llvm::Record &interfaceDef,
raw_ostream &os) {
static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
Interface interface(&interfaceDef);
// Emit the interface name followed by the description.
@ -684,15 +684,15 @@ struct InterfaceGenRegistration {
genDefDesc(("Generate " + genDesc + " interface definitions").str()),
genDocDesc(("Generate " + genDesc + " interface documentation").str()),
genDecls(genDeclArg, genDeclDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDecls();
}),
genDefs(genDefArg, genDefDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDefs();
}),
genDocs(genDocArg, genDocDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDocs();
}) {}

View File

@ -23,6 +23,9 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
/// File header and includes.
/// {0} is the dialect namespace.
@ -315,9 +318,9 @@ static std::string sanitizeName(StringRef name) {
}
static std::string attrSizedTraitForKind(const char *kind) {
return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
llvm::StringRef(kind).take_front().upper(),
llvm::StringRef(kind).drop_front());
return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
StringRef(kind).take_front().upper(),
StringRef(kind).drop_front());
}
/// Emits accessors to "elements" of an Op definition. Currently, the supported
@ -328,15 +331,14 @@ static void emitElementAccessors(
unsigned numVariadicGroups, unsigned numElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
assert(llvm::is_contained(
llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
kind) &&
"unsupported kind");
// Traits indicating how to process variadic elements.
std::string sameSizeTrait =
llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
llvm::StringRef(kind).take_front().upper(),
llvm::StringRef(kind).drop_front());
std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
StringRef(kind).take_front().upper(),
StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
// If there is only one variable-length element group, its size can be
@ -351,15 +353,14 @@ static void emitElementAccessors(
if (element.name.empty())
continue;
if (element.isVariableLength()) {
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
: opOneVariadicTemplate,
sanitizeName(element.name), kind, numElements, i);
os << formatv(element.isOptional() ? opOneOptionalTemplate
: opOneVariadicTemplate,
sanitizeName(element.name), kind, numElements, i);
} else if (seenVariableLength) {
os << llvm::formatv(opSingleAfterVariableTemplate,
sanitizeName(element.name), kind, numElements, i);
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
kind, numElements, i);
} else {
os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
i);
os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
}
}
return;
@ -382,14 +383,13 @@ static void emitElementAccessors(
for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
os << llvm::formatv(opVariadicEqualPrefixTemplate,
sanitizeName(element.name), kind, numSimpleLength,
numVariadicGroups, numPrecedingSimple,
numPrecedingVariadic);
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
kind);
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
kind, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic);
os << formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
kind);
}
if (element.isVariableLength())
++numPrecedingVariadic;
@ -412,9 +412,9 @@ static void emitElementAccessors(
trailing = "[0]";
else if (element.isOptional())
trailing = std::string(
llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
kind, i, trailing);
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
i, trailing);
}
return;
}
@ -459,27 +459,21 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
// Unit attributes are handled specially.
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name);
os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
continue;
}
if (namedAttr.attr.isOptional()) {
os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
namedAttr.name);
os << formatv(optionalAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << formatv(optionalAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
} else {
os << llvm::formatv(attributeGetterTemplate, sanitizedName,
namedAttr.name);
os << llvm::formatv(attributeSetterTemplate, sanitizedName,
namedAttr.name);
os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
// Non-optional attributes cannot be deleted.
}
}
@ -595,7 +589,7 @@ static bool canInferType(const Operator &op) {
/// accept them as arguments.
static void
populateBuilderArgsResults(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs) {
SmallVectorImpl<std::string> &builderArgs) {
if (canInferType(op))
return;
@ -607,7 +601,7 @@ populateBuilderArgsResults(const Operator &op,
// to properly match the built-in result accessor.
name = "result";
} else {
name = llvm::formatv("_gen_res_{0}", i);
name = formatv("_gen_res_{0}", i);
}
}
name = sanitizeName(name);
@ -620,14 +614,13 @@ populateBuilderArgsResults(const Operator &op,
/// appear in the `arguments` field of the op definition. Additionally,
/// `operandNames` is populated with names of operands in their order of
/// appearance.
static void
populateBuilderArgs(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &operandNames) {
static void populateBuilderArgs(const Operator &op,
SmallVectorImpl<std::string> &builderArgs,
SmallVectorImpl<std::string> &operandNames) {
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str();
if (name.empty())
name = llvm::formatv("_gen_arg_{0}", i);
name = formatv("_gen_arg_{0}", i);
name = sanitizeName(name);
builderArgs.push_back(name);
if (!op.getArg(i).is<NamedAttribute *>())
@ -637,15 +630,16 @@ populateBuilderArgs(const Operator &op,
/// Populates `builderArgs` with the Python-compatible names of builder function
/// successor arguments. Additionally, `successorArgNames` is also populated.
static void populateBuilderArgsSuccessors(
const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &successorArgNames) {
static void
populateBuilderArgsSuccessors(const Operator &op,
SmallVectorImpl<std::string> &builderArgs,
SmallVectorImpl<std::string> &successorArgNames) {
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
NamedSuccessor successor = op.getSuccessor(i);
std::string name = std::string(successor.name);
if (name.empty())
name = llvm::formatv("_gen_successor_{0}", i);
name = formatv("_gen_successor_{0}", i);
name = sanitizeName(name);
builderArgs.push_back(name);
successorArgNames.push_back(name);
@ -658,9 +652,8 @@ static void populateBuilderArgsSuccessors(
/// operands and attributes in the same order as they appear in the `arguments`
/// field.
static void
populateBuilderLinesAttr(const Operator &op,
llvm::ArrayRef<std::string> argNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
SmallVectorImpl<std::string> &builderLines) {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
@ -670,12 +663,12 @@ populateBuilderLinesAttr(const Operator &op,
// Unit attributes are handled specially.
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
attribute->name, argNames[i]));
builderLines.push_back(
formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
continue;
}
builderLines.push_back(llvm::formatv(
builderLines.push_back(formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
@ -686,30 +679,30 @@ populateBuilderLinesAttr(const Operator &op,
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up successors. successorArgNames is expected to correspond
/// to the Python argument name for each successor on the op.
static void populateBuilderLinesSuccessors(
const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
static void
populateBuilderLinesSuccessors(const Operator &op,
ArrayRef<std::string> successorArgNames,
SmallVectorImpl<std::string> &builderLines) {
if (successorArgNames.empty()) {
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
return;
}
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
auto &argName = successorArgNames[i];
const NamedSuccessor &successor = op.getSuccessor(i);
builderLines.push_back(
llvm::formatv(addSuccessorTemplate,
successor.isVariadic() ? "extend" : "append", argName));
builderLines.push_back(formatv(addSuccessorTemplate,
successor.isVariadic() ? "extend" : "append",
argName));
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op operands.
static void
populateBuilderLinesOperand(const Operator &op,
llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
// For each element, find or generate a name.
@ -718,7 +711,7 @@ populateBuilderLinesOperand(const Operator &op,
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleOperandAppendTemplate;
} else if (element.isOptional()) {
@ -738,7 +731,7 @@ populateBuilderLinesOperand(const Operator &op,
}
}
builderLines.push_back(llvm::formatv(formatString.data(), name));
builderLines.push_back(formatv(formatString.data(), name));
}
}
@ -758,7 +751,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// Appends the given multiline string as individual strings into
/// `builderLines`.
static void appendLineByLine(StringRef string,
llvm::SmallVectorImpl<std::string> &builderLines) {
SmallVectorImpl<std::string> &builderLines) {
std::pair<StringRef, StringRef> split = std::make_pair(string, string);
do {
@ -770,14 +763,13 @@ static void appendLineByLine(StringRef string,
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
static void
populateBuilderLinesResult(const Operator &op,
llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
if (hasSameArgumentAndResultTypes(op)) {
builderLines.push_back(llvm::formatv(
appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
builderLines.push_back(formatv(appendSameResultsTemplate,
"operands[0].type", op.getNumResults()));
return;
}
@ -785,12 +777,11 @@ populateBuilderLinesResult(const Operator &op,
const NamedAttribute &firstAttr = op.getAttribute(0);
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
"from which the type is derived");
appendLineByLine(
llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
builderLines);
builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
"_ods_derived_result_type",
op.getNumResults()));
appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
builderLines);
builderLines.push_back(formatv(appendSameResultsTemplate,
"_ods_derived_result_type",
op.getNumResults()));
return;
}
@ -803,7 +794,7 @@ populateBuilderLinesResult(const Operator &op,
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleResultAppendTemplate;
} else if (element.isOptional()) {
@ -819,17 +810,16 @@ populateBuilderLinesResult(const Operator &op,
}
}
builderLines.push_back(llvm::formatv(formatString.data(), name));
builderLines.push_back(formatv(formatString.data(), name));
}
}
/// If the operation has variadic regions, adds a builder argument to specify
/// the number of those regions and builder lines to forward it to the generic
/// constructor.
static void
populateBuilderRegions(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &builderLines) {
static void populateBuilderRegions(const Operator &op,
SmallVectorImpl<std::string> &builderArgs,
SmallVectorImpl<std::string> &builderLines) {
if (op.hasNoVariadicRegions())
return;
@ -844,19 +834,19 @@ populateBuilderRegions(const Operator &op,
.str();
builderArgs.push_back(name);
builderLines.push_back(
llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
}
/// Emits a default builder constructing an operation from the list of its
/// result types, followed by a list of its operands. Returns vector
/// of fully built functionArgs for downstream users (to save having to
/// rebuild anew).
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames;
llvm::SmallVector<std::string> successorArgNames;
static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
SmallVector<std::string> builderArgs;
SmallVector<std::string> builderLines;
SmallVector<std::string> operandArgNames;
SmallVector<std::string> successorArgNames;
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
op.getNumNativeAttributes() + op.getNumSuccessors());
populateBuilderArgsResults(op, builderArgs);
@ -866,10 +856,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
populateBuilderLinesOperand(op, operandArgNames, builderLines);
populateBuilderLinesAttr(
op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines);
populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
builderLines);
populateBuilderLinesResult(
op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
populateBuilderRegions(op, builderArgs, builderLines);
@ -896,7 +886,7 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
};
// StringRefs in functionArgs refer to strings allocated by builderArgs.
llvm::SmallVector<llvm::StringRef> functionArgs;
SmallVector<StringRef> functionArgs;
// Add positional arguments.
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
@ -929,11 +919,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
initArgs.push_back("loc=loc");
initArgs.push_back("ip=ip");
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "),
llvm::join(initArgs, ", "));
os << formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "), llvm::join(initArgs, ", "));
return llvm::to_vector<8>(
llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
llvm::map_range(functionArgs, [](StringRef s) { return s.str(); }));
}
static void emitSegmentSpec(
@ -955,15 +944,15 @@ static void emitSegmentSpec(
}
segmentSpec.append("]");
os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
}
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
// Note that the base OpView class defines this as (0, True).
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
op.hasNoVariadicRegions() ? "True" : "False");
os << formatv(opClassRegionSpecTemplate, minRegionCount,
op.hasNoVariadicRegions() ? "True" : "False");
}
/// Emits named accessors to regions.
@ -975,20 +964,20 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
"expected only the last region to be variadic");
os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
std::to_string(en.index()) +
(region.isVariadic() ? ":" : ""));
os << formatv(regionAccessorTemplate, sanitizeName(region.name),
std::to_string(en.index()) +
(region.isVariadic() ? ":" : ""));
}
}
/// Emits builder that extracts results from op
static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs,
SmallVector<std::string> functionArgs,
raw_ostream &os) {
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
llvm::SmallVector<llvm::StringRef> argMaybeDefault =
SmallVector<StringRef> argMaybeDefault =
llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
if (argMaybeDefault.size() == 2)
@ -1005,7 +994,7 @@ static void emitValueBuilder(const Operator &op,
});
std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
os << llvm::formatv(
os << formatv(
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
@ -1016,8 +1005,7 @@ static void emitValueBuilder(const Operator &op,
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
op.getOperationName());
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());
// Sized segments.
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
@ -1028,7 +1016,7 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
}
emitRegionAttributes(op, os);
llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
emitAttributeAccessors(op, os);
emitResultAccessors(op, os);
@ -1039,17 +1027,17 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
/// Emits bindings for the dialect specified in the command line, including file
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
os << fileHeader;
if (!clDialectExtensionName.empty())
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
os << formatv(dialectExtensionTemplate, clDialectName.getValue());
else
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
os << formatv(dialectClassTemplate, clDialectName.getValue());
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
if (op.getDialectName() == clDialectName.getValue())
emitOpBindings(op, os);

View File

@ -20,6 +20,8 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::RecordKeeper;
static llvm::cl::OptionCategory
passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl");
@ -56,7 +58,7 @@ const char *const fileFooter = R"(
)";
/// Emit TODO
static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
static bool emitCAPIHeader(const RecordKeeper &records, raw_ostream &os) {
os << fileHeader;
os << "// Registration for the entire group\n";
os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName
@ -64,7 +66,7 @@ static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
Pass pass(def);
StringRef defName = pass.getDef()->getName();
os << llvm::formatv(passDecl, groupName, defName);
os << formatv(passDecl, groupName, defName);
}
os << fileFooter;
return false;
@ -91,9 +93,9 @@ void mlirRegister{0}Passes(void) {{
}
)";
static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
static bool emitCAPIImpl(const RecordKeeper &records, raw_ostream &os) {
os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
os << llvm::formatv(passGroupRegistrationCode, groupName);
os << formatv(passGroupRegistrationCode, groupName);
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
Pass pass(def);
@ -103,10 +105,9 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str();
else
constructorCall =
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
os << llvm::formatv(passCreateDef, groupName, defName, constructorCall);
os << formatv(passCreateDef, groupName, defName, constructorCall);
}
return false;
}

View File

@ -18,6 +18,7 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::RecordKeeper;
/// Emit the documentation for the given pass.
static void emitDoc(const Pass &pass, raw_ostream &os) {
@ -56,7 +57,7 @@ static void emitDoc(const Pass &pass, raw_ostream &os) {
}
}
static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
static void emitDocs(const RecordKeeper &recordKeeper, raw_ostream &os) {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
auto passDefs = recordKeeper.getAllDerivedDefinitions("PassBase");
@ -74,7 +75,7 @@ static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
static mlir::GenRegistration
genRegister("gen-pass-doc", "Generate pass documentation",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
emitDocs(records, os);
return false;
});

View File

@ -21,6 +21,8 @@
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::RecordKeeper;
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
static llvm::cl::opt<std::string>
@ -28,7 +30,7 @@ static llvm::cl::opt<std::string>
llvm::cl::cat(passGenCat));
/// Extract the list of passes from the TableGen records.
static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
static std::vector<Pass> getPasses(const RecordKeeper &recordKeeper) {
std::vector<Pass> passes;
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
@ -91,7 +93,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
if (options.empty())
return;
os << llvm::formatv("struct {0}Options {{\n", passName);
os << formatv("struct {0}Options {{\n", passName);
for (const PassOption &opt : options) {
std::string type = opt.getType().str();
@ -99,7 +101,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
if (opt.isListOption())
type = "::llvm::SmallVector<" + type + ">";
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
os << " = " << defaultVal;
@ -128,9 +130,9 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
// Declaration of the constructor with options.
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}("
"{0}Options options);\n",
passName);
os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
"{0}Options options);\n",
passName);
}
os << "#undef " << enableVarName << "\n";
@ -147,14 +149,13 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
constructorCall = constructor.str();
else
constructorCall =
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
constructorCall);
os << formatv(passRegistrationCode, pass.getDef()->getName(),
constructorCall);
}
os << llvm::formatv(passGroupRegistrationCode, groupName);
os << formatv(passGroupRegistrationCode, groupName);
for (const Pass &pass : passes)
os << " register" << pass.getDef()->getName() << "();\n";
@ -270,9 +271,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
os.indent(2) << "::mlir::Pass::"
<< (opt.isListOption() ? "ListOption" : "Option");
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
opt.getType(), opt.getCppVariableName(),
opt.getArgument(), opt.getDescription());
os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
opt.getType(), opt.getCppVariableName(), opt.getArgument(),
opt.getDescription());
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
os << ", ::llvm::cl::init(" << defaultVal << ")";
if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
@ -284,9 +285,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
/// Emit the declarations for each of the pass statistics.
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
for (const PassStatistic &stat : pass.getStatistics()) {
os << llvm::formatv(
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
stat.getCppVariableName(), stat.getName(), stat.getDescription());
os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
stat.getCppVariableName(), stat.getName(),
stat.getDescription());
}
}
@ -300,11 +301,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
os << "#ifdef " << enableVarName << "\n";
if (emitDefaultConstructors) {
os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName);
os << formatv(friendDefaultConstructorDeclTemplate, passName);
if (emitDefaultConstructorWithOptions)
os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate,
passName);
os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName);
}
std::string dependentDialectRegistrations;
@ -313,24 +313,23 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
llvm::interleave(
pass.getDependentDialects(), dialectsOs,
[&](StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
},
"\n ");
}
os << "namespace impl {\n";
os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
os << formatv(baseClassBegin, passName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
os.indent(2) << llvm::formatv(
"{0}Base({0}Options options) : {0}Base() {{\n", passName);
os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n",
passName);
for (const PassOption &opt : pass.getOptions())
os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n",
opt.getCppVariableName());
os.indent(4) << formatv("{0} = std::move(options.{0});\n",
opt.getCppVariableName());
os.indent(2) << "}\n";
}
@ -344,21 +343,20 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
os << "private:\n";
if (emitDefaultConstructors) {
os << llvm::formatv(friendDefaultConstructorDefTemplate, passName);
os << formatv(friendDefaultConstructorDefTemplate, passName);
if (!pass.getOptions().empty())
os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate,
passName);
os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName);
}
os << "};\n";
os << "} // namespace impl\n";
if (emitDefaultConstructors) {
os << llvm::formatv(defaultConstructorDefTemplate, passName);
os << formatv(defaultConstructorDefTemplate, passName);
if (emitDefaultConstructorWithOptions)
os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName);
os << formatv(defaultConstructorWithOptionsDefTemplate, passName);
}
os << "#undef " << enableVarName << "\n";
@ -367,7 +365,7 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
static void emitPass(const Pass &pass, raw_ostream &os) {
StringRef passName = pass.getDef()->getName();
os << llvm::formatv(passHeader, passName);
os << formatv(passHeader, passName);
emitPassDecls(pass, os);
emitPassDefs(pass, os);
@ -436,21 +434,19 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
llvm::interleave(
pass.getDependentDialects(), dialectsOs,
[&](StringRef dependentDialect) {
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
dependentDialect);
dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
},
"\n ");
}
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
dependentDialectRegistrations);
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";
}
static void emitPasses(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
static void emitPasses(const RecordKeeper &recordKeeper, raw_ostream &os) {
std::vector<Pass> passes = getPasses(recordKeeper);
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
@ -479,7 +475,7 @@ static void emitPasses(const llvm::RecordKeeper &recordKeeper,
static mlir::GenRegistration
genPassDecls("gen-pass-decls", "Generate pass declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
[](const RecordKeeper &records, raw_ostream &os) {
emitPasses(records, os);
return false;
});

View File

@ -98,15 +98,15 @@ public:
StringRef getMergeInstance() const;
// Returns the underlying LLVM TableGen Record.
const llvm::Record *getDef() const { return def; }
const Record *getDef() const { return def; }
private:
// The TableGen definition of this availability.
const llvm::Record *def;
const Record *def;
};
} // namespace
Availability::Availability(const llvm::Record *def) : def(def) {
Availability::Availability(const Record *def) : def(def) {
assert(def->isSubClassOf("Availability") &&
"must be subclass of TableGen 'Availability' class");
}