[NFC][MLIR][TableGen] Eliminate llvm::
for commonly used types (#110841)
Eliminate `llvm::` namespace qualifier for commonly used types in MLIR TableGen backends to reduce code clutter.
This commit is contained in:
parent
906ffc4b4a
commit
bccd37f69f
@ -22,6 +22,8 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Functions
|
||||
@ -30,14 +32,14 @@ using namespace mlir::tblgen;
|
||||
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
|
||||
/// specified and can only find one dialect's defs, use that.
|
||||
static void collectAllDefs(StringRef selectedDialect,
|
||||
ArrayRef<const llvm::Record *> records,
|
||||
ArrayRef<const Record *> records,
|
||||
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
|
||||
// Nothing to do if no defs were found.
|
||||
if (records.empty())
|
||||
return;
|
||||
|
||||
auto defs = llvm::map_range(
|
||||
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
|
||||
records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
|
||||
if (selectedDialect.empty()) {
|
||||
// If a dialect was not specified, ensure that all found defs belong to the
|
||||
// same dialect.
|
||||
@ -690,15 +692,14 @@ public:
|
||||
bool emitDefs(StringRef selectedDialect);
|
||||
|
||||
protected:
|
||||
DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
|
||||
DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
|
||||
StringRef defType, StringRef valueType, bool isAttrGenerator)
|
||||
: defRecords(defs), os(os), defType(defType), valueType(valueType),
|
||||
isAttrGenerator(isAttrGenerator) {
|
||||
// Sort by occurrence in file.
|
||||
llvm::sort(defRecords,
|
||||
[](const llvm::Record *lhs, const llvm::Record *rhs) {
|
||||
return lhs->getID() < rhs->getID();
|
||||
});
|
||||
llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
|
||||
return lhs->getID() < rhs->getID();
|
||||
});
|
||||
}
|
||||
|
||||
/// Emit the list of def type names.
|
||||
@ -707,7 +708,7 @@ protected:
|
||||
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
|
||||
|
||||
/// The set of def records to emit.
|
||||
std::vector<const llvm::Record *> defRecords;
|
||||
std::vector<const Record *> defRecords;
|
||||
/// The attribute or type class to emit.
|
||||
/// The stream to emit to.
|
||||
raw_ostream &os;
|
||||
@ -722,13 +723,13 @@ protected:
|
||||
|
||||
/// A specialized generator for AttrDefs.
|
||||
struct AttrDefGenerator : public DefGenerator {
|
||||
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
||||
AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
|
||||
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
|
||||
};
|
||||
/// A specialized generator for TypeDefs.
|
||||
struct TypeDefGenerator : public DefGenerator {
|
||||
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
||||
TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
|
||||
"Type", "Type", /*isAttrGenerator=*/false) {}
|
||||
};
|
||||
@ -1030,9 +1031,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
|
||||
|
||||
/// Find all type constraints for which a C++ function should be generated.
|
||||
static std::vector<Constraint>
|
||||
getAllTypeConstraints(const llvm::RecordKeeper &records) {
|
||||
getAllTypeConstraints(const RecordKeeper &records) {
|
||||
std::vector<Constraint> result;
|
||||
for (const llvm::Record *def :
|
||||
for (const Record *def :
|
||||
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
|
||||
// Ignore constraints defined outside of the top-level file.
|
||||
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
|
||||
@ -1047,7 +1048,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
|
||||
return result;
|
||||
}
|
||||
|
||||
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
|
||||
static void emitTypeConstraintDecls(const RecordKeeper &records,
|
||||
raw_ostream &os) {
|
||||
static const char *const typeConstraintDecl = R"(
|
||||
bool {0}(::mlir::Type type);
|
||||
@ -1057,7 +1058,7 @@ bool {0}(::mlir::Type type);
|
||||
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
|
||||
}
|
||||
|
||||
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
|
||||
static void emitTypeConstraintDefs(const RecordKeeper &records,
|
||||
raw_ostream &os) {
|
||||
static const char *const typeConstraintDef = R"(
|
||||
bool {0}(::mlir::Type type) {
|
||||
@ -1088,13 +1089,13 @@ static llvm::cl::opt<std::string>
|
||||
|
||||
static mlir::GenRegistration
|
||||
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
AttrDefGenerator generator(records, os);
|
||||
return generator.emitDefs(attrDialect);
|
||||
});
|
||||
static mlir::GenRegistration
|
||||
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
AttrDefGenerator generator(records, os);
|
||||
return generator.emitDecls(attrDialect);
|
||||
});
|
||||
@ -1110,13 +1111,13 @@ static llvm::cl::opt<std::string>
|
||||
|
||||
static mlir::GenRegistration
|
||||
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
TypeDefGenerator generator(records, os);
|
||||
return generator.emitDefs(typeDialect);
|
||||
});
|
||||
static mlir::GenRegistration
|
||||
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
TypeDefGenerator generator(records, os);
|
||||
return generator.emitDecls(typeDialect);
|
||||
});
|
||||
@ -1124,14 +1125,14 @@ static mlir::GenRegistration
|
||||
static mlir::GenRegistration
|
||||
genTypeConstrDefs("gen-type-constraint-defs",
|
||||
"Generate type constraint definitions",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
emitTypeConstraintDefs(records, os);
|
||||
return false;
|
||||
});
|
||||
static mlir::GenRegistration
|
||||
genTypeConstrDecls("gen-type-constraint-decls",
|
||||
"Generate type constraint declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
emitTypeConstraintDecls(records, os);
|
||||
return false;
|
||||
});
|
||||
|
@ -18,11 +18,10 @@
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
|
||||
static llvm::cl::opt<std::string>
|
||||
selectedBcDialect("bytecode-dialect",
|
||||
llvm::cl::desc("The dialect to gen for"),
|
||||
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
|
||||
static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
|
||||
static cl::opt<std::string>
|
||||
selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
|
||||
cl::cat(dialectGenCat), cl::CommaSeparated);
|
||||
|
||||
namespace {
|
||||
|
||||
@ -306,7 +305,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
|
||||
auto funScope = os.scope("{\n", "}\n\n");
|
||||
|
||||
// Check that predicates specified if multiple bytecode instances.
|
||||
for (const llvm::Record *rec : make_second_range(vec)) {
|
||||
for (const Record *rec : make_second_range(vec)) {
|
||||
StringRef pred = rec->getValueAsString("printerPredicate");
|
||||
if (vec.size() > 1 && pred.empty()) {
|
||||
for (auto [index, rec] : vec) {
|
||||
|
@ -30,6 +30,8 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
|
||||
llvm::cl::opt<std::string>
|
||||
@ -39,8 +41,8 @@ llvm::cl::opt<std::string>
|
||||
/// Utility iterator used for filtering records for a specific dialect.
|
||||
namespace {
|
||||
using DialectFilterIterator =
|
||||
llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
|
||||
std::function<bool(const llvm::Record *)>>;
|
||||
llvm::filter_iterator<ArrayRef<Record *>::iterator,
|
||||
std::function<bool(const Record *)>>;
|
||||
} // namespace
|
||||
|
||||
static void populateDiscardableAttributes(
|
||||
@ -62,8 +64,8 @@ static void populateDiscardableAttributes(
|
||||
/// the given dialect.
|
||||
template <typename T>
|
||||
static iterator_range<DialectFilterIterator>
|
||||
filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
|
||||
auto filterFn = [&](const llvm::Record *record) {
|
||||
filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
|
||||
auto filterFn = [&](const Record *record) {
|
||||
return T(record).getDialect() == dialect;
|
||||
};
|
||||
return {DialectFilterIterator(records.begin(), records.end(), filterFn),
|
||||
@ -295,7 +297,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
|
||||
<< "::" << dialect.getCppClassName() << ")\n";
|
||||
}
|
||||
|
||||
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
|
||||
static bool emitDialectDecls(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
emitSourceFileHeader("Dialect Declarations", os, recordKeeper);
|
||||
|
||||
@ -340,8 +342,7 @@ static const char *const dialectDestructorStr = R"(
|
||||
|
||||
)";
|
||||
|
||||
static void emitDialectDef(Dialect &dialect,
|
||||
const llvm::RecordKeeper &recordKeeper,
|
||||
static void emitDialectDef(Dialect &dialect, const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
std::string cppClassName = dialect.getCppClassName();
|
||||
|
||||
@ -389,8 +390,7 @@ static void emitDialectDef(Dialect &dialect,
|
||||
os << llvm::formatv(dialectDestructorStr, cppClassName);
|
||||
}
|
||||
|
||||
static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
static bool emitDialectDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
emitSourceFileHeader("Dialect Definitions", os, recordKeeper);
|
||||
|
||||
auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
|
||||
@ -411,12 +411,12 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
|
||||
|
||||
static mlir::GenRegistration
|
||||
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitDialectDecls(records, os);
|
||||
});
|
||||
|
||||
static mlir::GenRegistration
|
||||
genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitDialectDefs(records, os);
|
||||
});
|
||||
|
@ -21,6 +21,9 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::formatv;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
/// File header and includes.
|
||||
constexpr const char *fileHeader = R"Py(
|
||||
@ -42,44 +45,42 @@ static std::string makePythonEnumCaseName(StringRef name) {
|
||||
|
||||
/// Emits the Python class for the given enum.
|
||||
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
|
||||
os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
|
||||
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
|
||||
os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
|
||||
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
|
||||
if (!enumAttr.getSummary().empty())
|
||||
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
|
||||
os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
|
||||
os << "\n";
|
||||
|
||||
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
||||
os << llvm::formatv(
|
||||
" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
|
||||
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
|
||||
: "auto()");
|
||||
os << formatv(" {0} = {1}\n",
|
||||
makePythonEnumCaseName(enumCase.getSymbol()),
|
||||
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
|
||||
: "auto()");
|
||||
}
|
||||
|
||||
os << "\n";
|
||||
|
||||
if (enumAttr.isBitEnum()) {
|
||||
os << llvm::formatv(" def __iter__(self):\n"
|
||||
" return iter([case for case in type(self) if "
|
||||
"(self & case) is case])\n");
|
||||
os << llvm::formatv(" def __len__(self):\n"
|
||||
" return bin(self).count(\"1\")\n");
|
||||
os << formatv(" def __iter__(self):\n"
|
||||
" return iter([case for case in type(self) if "
|
||||
"(self & case) is case])\n");
|
||||
os << formatv(" def __len__(self):\n"
|
||||
" return bin(self).count(\"1\")\n");
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
os << llvm::formatv(" def __str__(self):\n");
|
||||
os << formatv(" def __str__(self):\n");
|
||||
if (enumAttr.isBitEnum())
|
||||
os << llvm::formatv(" if len(self) > 1:\n"
|
||||
" return \"{0}\".join(map(str, self))\n",
|
||||
enumAttr.getDef().getValueAsString("separator"));
|
||||
os << formatv(" if len(self) > 1:\n"
|
||||
" return \"{0}\".join(map(str, self))\n",
|
||||
enumAttr.getDef().getValueAsString("separator"));
|
||||
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
||||
os << llvm::formatv(" if self is {0}.{1}:\n",
|
||||
enumAttr.getEnumClassName(),
|
||||
makePythonEnumCaseName(enumCase.getSymbol()));
|
||||
os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr());
|
||||
os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
|
||||
makePythonEnumCaseName(enumCase.getSymbol()));
|
||||
os << formatv(" return \"{0}\"\n", enumCase.getStr());
|
||||
}
|
||||
os << llvm::formatv(
|
||||
" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
|
||||
enumAttr.getEnumClassName());
|
||||
os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
|
||||
enumAttr.getEnumClassName());
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
@ -105,15 +106,13 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
|
||||
return true;
|
||||
}
|
||||
|
||||
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
|
||||
enumAttr.getAttrDefName());
|
||||
os << llvm::formatv("def _{0}(x, context):\n",
|
||||
enumAttr.getAttrDefName().lower());
|
||||
os << llvm::formatv(
|
||||
" return "
|
||||
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
|
||||
"context=context), int(x))\n\n",
|
||||
bitwidth);
|
||||
os << formatv("@register_attribute_builder(\"{0}\")\n",
|
||||
enumAttr.getAttrDefName());
|
||||
os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
|
||||
os << formatv(" return "
|
||||
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
|
||||
"context=context), int(x))\n\n",
|
||||
bitwidth);
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -123,26 +122,25 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
|
||||
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
|
||||
StringRef formatString,
|
||||
raw_ostream &os) {
|
||||
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
|
||||
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
|
||||
os << llvm::formatv(" return "
|
||||
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
|
||||
formatString);
|
||||
os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
|
||||
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
|
||||
os << formatv(" return "
|
||||
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
|
||||
formatString);
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Emits Python bindings for all enums in the record keeper. Returns
|
||||
/// `false` on success, `true` on failure.
|
||||
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
static bool emitPythonEnums(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
os << fileHeader;
|
||||
for (const llvm::Record *it :
|
||||
for (const Record *it :
|
||||
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
|
||||
EnumAttr enumAttr(*it);
|
||||
emitEnumClass(enumAttr, os);
|
||||
emitAttributeBuilder(enumAttr, os);
|
||||
}
|
||||
for (const llvm::Record *it :
|
||||
for (const Record *it :
|
||||
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
|
||||
AttrOrTypeDef attr(&*it);
|
||||
if (!attr.getMnemonic()) {
|
||||
@ -156,11 +154,11 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
|
||||
if (assemblyFormat == "`<` $value `>`") {
|
||||
emitDialectEnumAttributeBuilder(
|
||||
attr.getName(),
|
||||
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
|
||||
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
|
||||
} else if (assemblyFormat == "$value") {
|
||||
emitDialectEnumAttributeBuilder(
|
||||
attr.getName(),
|
||||
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
|
||||
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
|
||||
} else {
|
||||
llvm::errs()
|
||||
<< "unsupported assembly format for python enum bindings generation";
|
||||
|
@ -26,10 +26,9 @@
|
||||
using llvm::formatv;
|
||||
using llvm::isDigit;
|
||||
using llvm::PrintFatalError;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::StringRef;
|
||||
using namespace mlir;
|
||||
using mlir::tblgen::Attribute;
|
||||
using mlir::tblgen::EnumAttr;
|
||||
using mlir::tblgen::EnumAttrCase;
|
||||
@ -139,7 +138,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
|
||||
// is not a power of two (i.e. not a single bit case) and not a known case.
|
||||
} else if (enumAttr.isBitEnum()) {
|
||||
// Process the known multi-bit cases that use valid keywords.
|
||||
llvm::SmallVector<EnumAttrCase *> validMultiBitCases;
|
||||
SmallVector<EnumAttrCase *> validMultiBitCases;
|
||||
for (auto [index, caseVal] : llvm::enumerate(cases)) {
|
||||
uint64_t value = caseVal.getValue();
|
||||
if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
|
||||
@ -476,7 +475,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef enumName = enumAttr.getEnumClassName();
|
||||
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
|
||||
const llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
|
||||
const Record *baseAttrDef = enumAttr.getBaseAttrClass();
|
||||
Attribute baseAttr(baseAttrDef);
|
||||
|
||||
// Emit classof method
|
||||
@ -565,7 +564,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
|
||||
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
|
||||
auto enumerants = enumAttr.getAllCases();
|
||||
|
||||
llvm::SmallVector<StringRef, 2> namespaces;
|
||||
SmallVector<StringRef, 2> namespaces;
|
||||
llvm::SplitString(cppNamespace, namespaces, "::");
|
||||
|
||||
for (auto ns : namespaces)
|
||||
@ -656,7 +655,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
|
||||
EnumAttr enumAttr(enumDef);
|
||||
StringRef cppNamespace = enumAttr.getCppNamespace();
|
||||
|
||||
llvm::SmallVector<StringRef, 2> namespaces;
|
||||
SmallVector<StringRef, 2> namespaces;
|
||||
llvm::SplitString(cppNamespace, namespaces, "::");
|
||||
|
||||
for (auto ns : namespaces)
|
||||
|
@ -13,6 +13,7 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::SourceMgr;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FormatToken
|
||||
@ -26,14 +27,14 @@ SMLoc FormatToken::getLoc() const {
|
||||
// FormatLexer
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc)
|
||||
FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc)
|
||||
: mgr(mgr), loc(loc),
|
||||
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
|
||||
curPtr(curBuffer.begin()) {}
|
||||
|
||||
FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
|
||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
|
||||
mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
|
||||
"in custom assembly format for this operation");
|
||||
return formToken(FormatToken::error, loc.getPointer());
|
||||
}
|
||||
@ -44,10 +45,10 @@ FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
|
||||
|
||||
FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
|
||||
const Twine ¬e) {
|
||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
|
||||
mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
|
||||
llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
|
||||
"in custom assembly format for this operation");
|
||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
|
||||
mgr.PrintMessage(loc, SourceMgr::DK_Note, note);
|
||||
return formToken(FormatToken::error, loc.getPointer());
|
||||
}
|
||||
|
||||
|
@ -411,8 +411,7 @@ public:
|
||||
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
|
||||
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
|
||||
// (Enum) to the corresponding LLVM API enumerant
|
||||
static void emitOneEnumToConversion(const llvm::Record *record,
|
||||
raw_ostream &os) {
|
||||
static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
|
||||
LLVMEnumAttr enumAttr(record);
|
||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||
@ -441,8 +440,7 @@ static void emitOneEnumToConversion(const llvm::Record *record,
|
||||
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
|
||||
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
|
||||
// (Enum) to the corresponding LLVM API C-style enumerant
|
||||
static void emitOneCEnumToConversion(const llvm::Record *record,
|
||||
raw_ostream &os) {
|
||||
static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
|
||||
LLVMCEnumAttr enumAttr(record);
|
||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||
@ -472,8 +470,7 @@ static void emitOneCEnumToConversion(const llvm::Record *record,
|
||||
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
|
||||
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
|
||||
// LLVM dialect enum attribute (Enum).
|
||||
static void emitOneEnumFromConversion(const llvm::Record *record,
|
||||
raw_ostream &os) {
|
||||
static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
|
||||
LLVMEnumAttr enumAttr(record);
|
||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||
@ -508,8 +505,7 @@ static void emitOneEnumFromConversion(const llvm::Record *record,
|
||||
// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
|
||||
// containing switch-based logic to convert from the LLVM API C-style enumerant
|
||||
// to MLIR LLVM dialect enum attribute (Enum).
|
||||
static void emitOneCEnumFromConversion(const llvm::Record *record,
|
||||
raw_ostream &os) {
|
||||
static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
|
||||
LLVMCEnumAttr enumAttr(record);
|
||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||
|
@ -24,6 +24,11 @@
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::Regex;
|
||||
using namespace mlir;
|
||||
|
||||
static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options");
|
||||
|
||||
static llvm::cl::opt<std::string>
|
||||
@ -54,14 +59,14 @@ static llvm::cl::opt<std::string> aliasAnalysisRegexp(
|
||||
using IndicesTy = llvm::SmallBitVector;
|
||||
|
||||
/// Return a CodeGen value type entry from a type record.
|
||||
static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) {
|
||||
static llvm::MVT::SimpleValueType getValueType(const Record *rec) {
|
||||
return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
|
||||
"Value");
|
||||
}
|
||||
|
||||
/// Return the indices of the definitions in a list of definitions that
|
||||
/// represent overloadable types
|
||||
static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record,
|
||||
static IndicesTy getOverloadableTypeIdxs(const Record &record,
|
||||
const char *listName) {
|
||||
auto results = record.getValueAsListOfDefs(listName);
|
||||
IndicesTy overloadedOps(results.size());
|
||||
@ -87,13 +92,13 @@ namespace {
|
||||
/// the fields of the record.
|
||||
class LLVMIntrinsic {
|
||||
public:
|
||||
LLVMIntrinsic(const llvm::Record &record) : record(record) {}
|
||||
LLVMIntrinsic(const Record &record) : record(record) {}
|
||||
|
||||
/// Get the name of the operation to be used in MLIR. Uses the appropriate
|
||||
/// field if not empty, constructs a name by replacing underscores with dots
|
||||
/// in the record name otherwise.
|
||||
std::string getOperationName() const {
|
||||
llvm::StringRef name = record.getValueAsString(fieldName);
|
||||
StringRef name = record.getValueAsString(fieldName);
|
||||
if (!name.empty())
|
||||
return name.str();
|
||||
|
||||
@ -101,8 +106,8 @@ public:
|
||||
assert(name.starts_with("int_") &&
|
||||
"LLVM intrinsic names are expected to start with 'int_'");
|
||||
name = name.drop_front(4);
|
||||
llvm::SmallVector<llvm::StringRef, 8> chunks;
|
||||
llvm::StringRef targetPrefix = record.getValueAsString("TargetPrefix");
|
||||
SmallVector<StringRef, 8> chunks;
|
||||
StringRef targetPrefix = record.getValueAsString("TargetPrefix");
|
||||
name.split(chunks, '_');
|
||||
auto *chunksBegin = chunks.begin();
|
||||
// Remove the target prefix from target specific intrinsics.
|
||||
@ -119,8 +124,8 @@ public:
|
||||
}
|
||||
|
||||
/// Get the name of the record without the "intrinsic" prefix.
|
||||
llvm::StringRef getProperRecordName() const {
|
||||
llvm::StringRef name = record.getName();
|
||||
StringRef getProperRecordName() const {
|
||||
StringRef name = record.getName();
|
||||
assert(name.starts_with("int_") &&
|
||||
"LLVM intrinsic names are expected to start with 'int_'");
|
||||
return name.drop_front(4);
|
||||
@ -129,10 +134,9 @@ public:
|
||||
/// Get the number of operands.
|
||||
unsigned getNumOperands() const {
|
||||
auto operands = record.getValueAsListOfDefs(fieldOperands);
|
||||
assert(llvm::all_of(operands,
|
||||
[](const llvm::Record *r) {
|
||||
return r->isSubClassOf("LLVMType");
|
||||
}) &&
|
||||
assert(llvm::all_of(
|
||||
operands,
|
||||
[](const Record *r) { return r->isSubClassOf("LLVMType"); }) &&
|
||||
"expected operands to be of LLVM type");
|
||||
return operands.size();
|
||||
}
|
||||
@ -142,7 +146,7 @@ public:
|
||||
/// structure type.
|
||||
unsigned getNumResults() const {
|
||||
auto results = record.getValueAsListOfDefs(fieldResults);
|
||||
for (const llvm::Record *r : results) {
|
||||
for (const Record *r : results) {
|
||||
(void)r;
|
||||
assert(r->isSubClassOf("LLVMType") &&
|
||||
"expected operands to be of LLVM type");
|
||||
@ -155,7 +159,7 @@ public:
|
||||
bool hasSideEffects() const {
|
||||
return llvm::none_of(
|
||||
record.getValueAsListOfDefs(fieldTraits),
|
||||
[](const llvm::Record *r) { return r->getName() == "IntrNoMem"; });
|
||||
[](const Record *r) { return r->getName() == "IntrNoMem"; });
|
||||
}
|
||||
|
||||
/// Return true if the intrinsic is commutative, i.e. has the respective
|
||||
@ -163,7 +167,7 @@ public:
|
||||
bool isCommutative() const {
|
||||
return llvm::any_of(
|
||||
record.getValueAsListOfDefs(fieldTraits),
|
||||
[](const llvm::Record *r) { return r->getName() == "Commutative"; });
|
||||
[](const Record *r) { return r->getName() == "Commutative"; });
|
||||
}
|
||||
|
||||
IndicesTy getOverloadableOperandsIdxs() const {
|
||||
@ -181,7 +185,7 @@ private:
|
||||
const char *fieldResults = "RetTypes";
|
||||
const char *fieldTraits = "IntrProperties";
|
||||
|
||||
const llvm::Record &record;
|
||||
const Record &record;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -195,27 +199,26 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
|
||||
|
||||
/// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic.
|
||||
/// Returns true on error, false on success.
|
||||
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
|
||||
static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
|
||||
LLVMIntrinsic intr(record);
|
||||
|
||||
llvm::Regex accessGroupMatcher(accessGroupRegexp);
|
||||
Regex accessGroupMatcher(accessGroupRegexp);
|
||||
bool requiresAccessGroup =
|
||||
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
|
||||
|
||||
llvm::Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
|
||||
Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
|
||||
bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() &&
|
||||
aliasAnalysisMatcher.match(record.getName());
|
||||
|
||||
// Prepare strings for traits, if any.
|
||||
llvm::SmallVector<llvm::StringRef, 2> traits;
|
||||
SmallVector<StringRef, 2> traits;
|
||||
if (intr.isCommutative())
|
||||
traits.push_back("Commutative");
|
||||
if (!intr.hasSideEffects())
|
||||
traits.push_back("NoMemoryEffect");
|
||||
|
||||
// Prepare strings for operands.
|
||||
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
|
||||
"LLVM_Type");
|
||||
SmallVector<StringRef, 8> operands(intr.getNumOperands(), "LLVM_Type");
|
||||
if (requiresAccessGroup)
|
||||
operands.push_back(
|
||||
"OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups");
|
||||
@ -247,14 +250,13 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
|
||||
/// Traverses the list of TableGen definitions derived from the "Intrinsic"
|
||||
/// class and generates MLIR ODS definitions for those intrinsics that have
|
||||
/// the name matching the filter.
|
||||
static bool emitIntrinsics(const llvm::RecordKeeper &records,
|
||||
llvm::raw_ostream &os) {
|
||||
static bool emitIntrinsics(const RecordKeeper &records, llvm::raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records);
|
||||
os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n";
|
||||
os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n";
|
||||
|
||||
auto defs = records.getAllDerivedDefinitions("Intrinsic");
|
||||
for (const llvm::Record *r : defs) {
|
||||
for (const Record *r : defs) {
|
||||
if (!nameFilter.empty() && !r->getName().contains(nameFilter))
|
||||
continue;
|
||||
if (emitIntrinsic(*r, os))
|
||||
|
@ -34,30 +34,30 @@
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Commandline Options
|
||||
//===----------------------------------------------------------------------===//
|
||||
static llvm::cl::OptionCategory
|
||||
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
|
||||
llvm::cl::opt<std::string>
|
||||
stripPrefix("strip-prefix",
|
||||
llvm::cl::desc("Strip prefix of the fully qualified names"),
|
||||
llvm::cl::init("::mlir::"), llvm::cl::cat(docCat));
|
||||
llvm::cl::opt<bool> allowHugoSpecificFeatures(
|
||||
"allow-hugo-specific-features",
|
||||
llvm::cl::desc("Allows using features specific to Hugo"),
|
||||
llvm::cl::init(false), llvm::cl::cat(docCat));
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Commandline Options
|
||||
//===----------------------------------------------------------------------===//
|
||||
static cl::OptionCategory
|
||||
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
|
||||
cl::opt<std::string>
|
||||
stripPrefix("strip-prefix",
|
||||
cl::desc("Strip prefix of the fully qualified names"),
|
||||
cl::init("::mlir::"), cl::cat(docCat));
|
||||
cl::opt<bool> allowHugoSpecificFeatures(
|
||||
"allow-hugo-specific-features",
|
||||
cl::desc("Allows using features specific to Hugo"), cl::init(false),
|
||||
cl::cat(docCat));
|
||||
|
||||
void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) {
|
||||
if (!summary.empty()) {
|
||||
llvm::StringRef trimmed = summary.trim();
|
||||
StringRef trimmed = summary.trim();
|
||||
char first = std::toupper(trimmed.front());
|
||||
llvm::StringRef rest = trimmed.drop_front();
|
||||
StringRef rest = trimmed.drop_front();
|
||||
os << "\n_" << first << rest << "_\n\n";
|
||||
}
|
||||
}
|
||||
@ -152,10 +152,10 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
|
||||
effectName.consume_front("::");
|
||||
effectName.consume_front("mlir::");
|
||||
std::string effectStr;
|
||||
llvm::raw_string_ostream os(effectStr);
|
||||
raw_string_ostream os(effectStr);
|
||||
os << effectName << "{";
|
||||
auto list = trait.getDef().getValueAsListOfDefs("effects");
|
||||
llvm::interleaveComma(list, os, [&](const Record *rec) {
|
||||
interleaveComma(list, os, [&](const Record *rec) {
|
||||
StringRef effect = rec->getValueAsString("effect");
|
||||
effect.consume_front("::");
|
||||
effect.consume_front("mlir::");
|
||||
@ -163,7 +163,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
|
||||
});
|
||||
os << "}";
|
||||
effects.insert(backticks(effectStr));
|
||||
name.append(llvm::formatv(" ({0})", traitName).str());
|
||||
name.append(formatv(" ({0})", traitName).str());
|
||||
}
|
||||
interfaces.insert(backticks(name));
|
||||
continue;
|
||||
@ -172,15 +172,15 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
|
||||
traits.insert(backticks(name));
|
||||
}
|
||||
if (!traits.empty()) {
|
||||
llvm::interleaveComma(traits, os << "\nTraits: ");
|
||||
interleaveComma(traits, os << "\nTraits: ");
|
||||
os << "\n";
|
||||
}
|
||||
if (!interfaces.empty()) {
|
||||
llvm::interleaveComma(interfaces, os << "\nInterfaces: ");
|
||||
interleaveComma(interfaces, os << "\nInterfaces: ");
|
||||
os << "\n";
|
||||
}
|
||||
if (!effects.empty()) {
|
||||
llvm::interleaveComma(effects, os << "\nEffects: ");
|
||||
interleaveComma(effects, os << "\nEffects: ");
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
@ -196,7 +196,7 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
|
||||
std::string classNameStr = op.getQualCppClassName();
|
||||
StringRef className = classNameStr;
|
||||
(void)className.consume_front(stripPrefix);
|
||||
os << llvm::formatv("### `{0}` ({1})\n", op.getOperationName(), className);
|
||||
os << formatv("### `{0}` ({1})\n", op.getOperationName(), className);
|
||||
|
||||
// Emit the summary, syntax, and description if present.
|
||||
if (op.hasSummary())
|
||||
@ -287,7 +287,7 @@ static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
|
||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||
emitSourceLink(recordKeeper.getInputFilename(), os);
|
||||
for (const llvm::Record *opDef : opDefs)
|
||||
for (const Record *opDef : opDefs)
|
||||
emitOpDoc(Operator(opDef), os);
|
||||
}
|
||||
|
||||
@ -339,7 +339,7 @@ static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def,
|
||||
}
|
||||
|
||||
static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
|
||||
os << llvm::formatv("### {0}\n", def.getCppClassName());
|
||||
os << formatv("### {0}\n", def.getCppClassName());
|
||||
|
||||
// Emit the summary if present.
|
||||
if (def.hasSummary())
|
||||
@ -376,7 +376,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName);
|
||||
|
||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||
for (const llvm::Record *def : defs)
|
||||
for (const Record *def : defs)
|
||||
emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os);
|
||||
}
|
||||
|
||||
@ -385,7 +385,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
|
||||
os << llvm::formatv("### {0}\n", def.getEnumClassName());
|
||||
os << formatv("### {0}\n", def.getEnumClassName());
|
||||
|
||||
// Emit the summary if present.
|
||||
if (!def.getSummary().empty())
|
||||
@ -406,8 +406,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
|
||||
|
||||
static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||
for (const llvm::Record *def :
|
||||
recordKeeper.getAllDerivedDefinitions("EnumAttr"))
|
||||
for (const Record *def : recordKeeper.getAllDerivedDefinitions("EnumAttr"))
|
||||
emitEnumDoc(EnumAttr(def), os);
|
||||
}
|
||||
|
||||
@ -431,7 +430,7 @@ struct OpDocGroup {
|
||||
static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
|
||||
raw_ostream &os) {
|
||||
std::string str;
|
||||
llvm::raw_string_ostream ss(str);
|
||||
raw_string_ostream ss(str);
|
||||
fn(ss);
|
||||
for (StringRef x : llvm::split(str, "\n")) {
|
||||
if (nest && x.starts_with("#"))
|
||||
@ -507,7 +506,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
|
||||
emitIfNotEmpty(dialect.getDescription(), os);
|
||||
|
||||
// Generate a TOC marker except if description already contains one.
|
||||
llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline);
|
||||
Regex r("^[[:space:]]*\\[TOC\\]$", Regex::RegexFlags::Newline);
|
||||
if (!r.match(dialect.getDescription()))
|
||||
os << "[TOC]\n\n";
|
||||
|
||||
@ -537,17 +536,15 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
std::vector<TypeDef> dialectTypeDefs;
|
||||
std::vector<EnumAttr> dialectEnums;
|
||||
|
||||
llvm::SmallDenseSet<const Record *> seen;
|
||||
auto addIfNotSeen = [&](const llvm::Record *record, const auto &def,
|
||||
auto &vec) {
|
||||
SmallDenseSet<const Record *> seen;
|
||||
auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) {
|
||||
if (seen.insert(record).second) {
|
||||
vec.push_back(def);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
auto addIfInDialect = [&](const llvm::Record *record, const auto &def,
|
||||
auto &vec) {
|
||||
auto addIfInDialect = [&](const Record *record, const auto &def, auto &vec) {
|
||||
return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
|
||||
};
|
||||
|
||||
|
@ -28,6 +28,9 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::formatv;
|
||||
using llvm::Record;
|
||||
using llvm::StringMap;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VariableElement
|
||||
@ -404,7 +407,7 @@ struct OperationFormat {
|
||||
StringRef opCppClassName;
|
||||
|
||||
/// A map of buildable types to indices.
|
||||
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
|
||||
llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes;
|
||||
|
||||
/// The index of the buildable type, if valid, for every operand and result.
|
||||
std::vector<TypeResolution> operandTypes, resultTypes;
|
||||
@ -891,8 +894,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
||||
|
||||
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
|
||||
const NamedAttribute *var = attr->getVar();
|
||||
body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
|
||||
var->name);
|
||||
body << formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), var->name);
|
||||
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
|
||||
StringRef name = operand->getVar()->name;
|
||||
@ -910,31 +912,31 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
||||
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
|
||||
<< name << "Operands(&" << name << "RawOperand, 1);";
|
||||
}
|
||||
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
|
||||
" (void){0}OperandsLoc;\n",
|
||||
name);
|
||||
body << formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
|
||||
" (void){0}OperandsLoc;\n",
|
||||
name);
|
||||
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
||||
StringRef name = region->getVar()->name;
|
||||
if (region->getVar()->isVariadic()) {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
|
||||
"{0}Regions;\n",
|
||||
name);
|
||||
} else {
|
||||
body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
|
||||
"std::make_unique<::mlir::Region>();\n",
|
||||
name);
|
||||
body << formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
|
||||
"std::make_unique<::mlir::Region>();\n",
|
||||
name);
|
||||
}
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||
StringRef name = successor->getVar()->name;
|
||||
if (successor->getVar()->isVariadic()) {
|
||||
body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
|
||||
"{0}Successors;\n",
|
||||
name);
|
||||
body << formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
|
||||
"{0}Successors;\n",
|
||||
name);
|
||||
} else {
|
||||
body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
|
||||
body << formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
|
||||
}
|
||||
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||
@ -944,8 +946,8 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
||||
body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
|
||||
else
|
||||
body
|
||||
<< llvm::formatv(" ::mlir::Type {0}RawType{{};\n", name)
|
||||
<< llvm::formatv(
|
||||
<< formatv(" ::mlir::Type {0}RawType{{};\n", name)
|
||||
<< formatv(
|
||||
" ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
|
||||
name);
|
||||
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
|
||||
@ -969,27 +971,27 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
|
||||
StringRef name = operand->getVar()->name;
|
||||
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
||||
body << llvm::formatv("{0}OperandGroups", name);
|
||||
body << formatv("{0}OperandGroups", name);
|
||||
else if (lengthKind == ArgumentLengthKind::Variadic)
|
||||
body << llvm::formatv("{0}Operands", name);
|
||||
body << formatv("{0}Operands", name);
|
||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||
body << llvm::formatv("{0}Operand", name);
|
||||
body << formatv("{0}Operand", name);
|
||||
else
|
||||
body << formatv("{0}RawOperand", name);
|
||||
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(param)) {
|
||||
StringRef name = region->getVar()->name;
|
||||
if (region->getVar()->isVariadic())
|
||||
body << llvm::formatv("{0}Regions", name);
|
||||
body << formatv("{0}Regions", name);
|
||||
else
|
||||
body << llvm::formatv("*{0}Region", name);
|
||||
body << formatv("*{0}Region", name);
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
|
||||
StringRef name = successor->getVar()->name;
|
||||
if (successor->getVar()->isVariadic())
|
||||
body << llvm::formatv("{0}Successors", name);
|
||||
body << formatv("{0}Successors", name);
|
||||
else
|
||||
body << llvm::formatv("{0}Successor", name);
|
||||
body << formatv("{0}Successor", name);
|
||||
|
||||
} else if (auto *dir = dyn_cast<RefDirective>(param)) {
|
||||
genCustomParameterParser(dir->getArg(), body);
|
||||
@ -998,11 +1000,11 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
||||
body << llvm::formatv("{0}TypeGroups", listName);
|
||||
body << formatv("{0}TypeGroups", listName);
|
||||
else if (lengthKind == ArgumentLengthKind::Variadic)
|
||||
body << llvm::formatv("{0}Types", listName);
|
||||
body << formatv("{0}Types", listName);
|
||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||
body << llvm::formatv("{0}Type", listName);
|
||||
body << formatv("{0}Type", listName);
|
||||
else
|
||||
body << formatv("{0}RawType", listName);
|
||||
|
||||
@ -1013,8 +1015,8 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
|
||||
body << tgfmt(string->getValue(), &ctx);
|
||||
|
||||
} else if (auto *property = dyn_cast<PropertyVariable>(param)) {
|
||||
body << llvm::formatv("result.getOrAddProperties<Properties>().{0}",
|
||||
property->getVar()->name);
|
||||
body << formatv("result.getOrAddProperties<Properties>().{0}",
|
||||
property->getVar()->name);
|
||||
} else {
|
||||
llvm_unreachable("unknown custom directive parameter");
|
||||
}
|
||||
@ -1037,24 +1039,24 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
body << " " << var->name
|
||||
<< "OperandsLoc = parser.getCurrentLocation();\n";
|
||||
if (var->isOptional()) {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
|
||||
"{0}Operand;\n",
|
||||
var->name);
|
||||
} else if (var->isVariadicOfVariadic()) {
|
||||
body << llvm::formatv(" "
|
||||
"::llvm::SmallVector<::llvm::SmallVector<::mlir::"
|
||||
"OpAsmParser::UnresolvedOperand>> "
|
||||
"{0}OperandGroups;\n",
|
||||
var->name);
|
||||
body << formatv(" "
|
||||
"::llvm::SmallVector<::llvm::SmallVector<::mlir::"
|
||||
"OpAsmParser::UnresolvedOperand>> "
|
||||
"{0}OperandGroups;\n",
|
||||
var->name);
|
||||
}
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(param)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::Optional) {
|
||||
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
|
||||
body << formatv(" ::mlir::Type {0}Type;\n", listName);
|
||||
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
|
||||
"{0}TypeGroups;\n",
|
||||
listName);
|
||||
@ -1064,7 +1066,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
if (auto *operand = dyn_cast<OperandVariable>(input)) {
|
||||
if (!operand->getVar()->isOptional())
|
||||
continue;
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" {0} {1}Operand = {1}Operands.empty() ? {0}() : "
|
||||
"{1}Operands[0];\n",
|
||||
"::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
|
||||
@ -1074,9 +1076,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(type->getArg(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::Optional) {
|
||||
body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
|
||||
"::mlir::Type() : {0}Types[0];\n",
|
||||
listName);
|
||||
body << formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
|
||||
"::mlir::Type() : {0}Types[0];\n",
|
||||
listName);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1101,23 +1103,23 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
|
||||
const NamedAttribute *var = attr->getVar();
|
||||
if (var->attr.isOptional() || var->attr.hasDefaultValue())
|
||||
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
|
||||
body << formatv(" if ({0}Attr)\n ", var->name);
|
||||
if (useProperties) {
|
||||
body << formatv(
|
||||
" result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
|
||||
var->name, opCppClassName);
|
||||
} else {
|
||||
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
|
||||
var->name);
|
||||
body << formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
|
||||
var->name);
|
||||
}
|
||||
} else if (auto *operand = dyn_cast<OperandVariable>(param)) {
|
||||
const NamedTypeConstraint *var = operand->getVar();
|
||||
if (var->isOptional()) {
|
||||
body << llvm::formatv(" if ({0}Operand.has_value())\n"
|
||||
" {0}Operands.push_back(*{0}Operand);\n",
|
||||
var->name);
|
||||
body << formatv(" if ({0}Operand.has_value())\n"
|
||||
" {0}Operands.push_back(*{0}Operand);\n",
|
||||
var->name);
|
||||
} else if (var->isVariadicOfVariadic()) {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" for (const auto &subRange : {0}OperandGroups) {{\n"
|
||||
" {0}Operands.append(subRange.begin(), subRange.end());\n"
|
||||
" {0}OperandGroupSizes.push_back(subRange.size());\n"
|
||||
@ -1128,11 +1130,11 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::Optional) {
|
||||
body << llvm::formatv(" if ({0}Type)\n"
|
||||
" {0}Types.push_back({0}Type);\n",
|
||||
listName);
|
||||
body << formatv(" if ({0}Type)\n"
|
||||
" {0}Types.push_back({0}Type);\n",
|
||||
listName);
|
||||
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" for (const auto &subRange : {0}TypeGroups)\n"
|
||||
" {0}Types.append(subRange.begin(), subRange.end());\n",
|
||||
listName);
|
||||
@ -1460,9 +1462,9 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
||||
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
|
||||
} else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) {
|
||||
genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false);
|
||||
body << llvm::formatv("if ({0}PropParseResult.has_value() && "
|
||||
"succeeded(*{0}PropParseResult)) ",
|
||||
propVar->getVar()->name)
|
||||
body << formatv("if ({0}PropParseResult.has_value() && "
|
||||
"succeeded(*{0}PropParseResult)) ",
|
||||
propVar->getVar()->name)
|
||||
<< " {\n";
|
||||
} else if (auto *literal = dyn_cast<LiteralElement>(firstElement)) {
|
||||
body << " if (::mlir::succeeded(parser.parseOptional";
|
||||
@ -1477,13 +1479,12 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
||||
genElementParser(regionVar, body, attrTypeCtx);
|
||||
body << " if (!" << region->name << "Regions.empty()) {\n";
|
||||
} else {
|
||||
body << llvm::formatv(optionalRegionParserCode, region->name);
|
||||
body << formatv(optionalRegionParserCode, region->name);
|
||||
body << " if (!" << region->name << "Region->empty()) {\n ";
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
|
||||
body << formatv(regionEnsureTerminatorParserCode, region->name);
|
||||
else if (hasSingleBlockTrait)
|
||||
body << llvm::formatv(regionEnsureSingleBlockParserCode,
|
||||
region->name);
|
||||
body << formatv(regionEnsureSingleBlockParserCode, region->name);
|
||||
}
|
||||
} else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
|
||||
body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
|
||||
@ -1575,26 +1576,26 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
||||
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
||||
StringRef name = operand->getVar()->name;
|
||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
||||
body << llvm::formatv(variadicOfVariadicOperandParserCode, name);
|
||||
body << formatv(variadicOfVariadicOperandParserCode, name);
|
||||
else if (lengthKind == ArgumentLengthKind::Variadic)
|
||||
body << llvm::formatv(variadicOperandParserCode, name);
|
||||
body << formatv(variadicOperandParserCode, name);
|
||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||
body << llvm::formatv(optionalOperandParserCode, name);
|
||||
body << formatv(optionalOperandParserCode, name);
|
||||
else
|
||||
body << formatv(operandParserCode, name);
|
||||
|
||||
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
||||
bool isVariadic = region->getVar()->isVariadic();
|
||||
body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
|
||||
region->getVar()->name);
|
||||
body << formatv(isVariadic ? regionListParserCode : regionParserCode,
|
||||
region->getVar()->name);
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
|
||||
: regionEnsureTerminatorParserCode,
|
||||
region->getVar()->name);
|
||||
body << formatv(isVariadic ? regionListEnsureTerminatorParserCode
|
||||
: regionEnsureTerminatorParserCode,
|
||||
region->getVar()->name);
|
||||
else if (hasSingleBlockTrait)
|
||||
body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
|
||||
: regionEnsureSingleBlockParserCode,
|
||||
region->getVar()->name);
|
||||
body << formatv(isVariadic ? regionListEnsureSingleBlockParserCode
|
||||
: regionEnsureSingleBlockParserCode,
|
||||
region->getVar()->name);
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||
bool isVariadic = successor->getVar()->isVariadic();
|
||||
@ -1631,24 +1632,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
||||
<< " return ::mlir::failure();\n";
|
||||
|
||||
} else if (isa<RegionsDirective>(element)) {
|
||||
body << llvm::formatv(regionListParserCode, "full");
|
||||
body << formatv(regionListParserCode, "full");
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
|
||||
body << formatv(regionListEnsureTerminatorParserCode, "full");
|
||||
else if (hasSingleBlockTrait)
|
||||
body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
|
||||
body << formatv(regionListEnsureSingleBlockParserCode, "full");
|
||||
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
body << llvm::formatv(successorListParserCode, "full");
|
||||
body << formatv(successorListParserCode, "full");
|
||||
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||
ArgumentLengthKind lengthKind;
|
||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
||||
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
|
||||
body << formatv(variadicOfVariadicTypeParserCode, listName);
|
||||
} else if (lengthKind == ArgumentLengthKind::Variadic) {
|
||||
body << llvm::formatv(variadicTypeParserCode, listName);
|
||||
body << formatv(variadicTypeParserCode, listName);
|
||||
} else if (lengthKind == ArgumentLengthKind::Optional) {
|
||||
body << llvm::formatv(optionalTypeParserCode, listName);
|
||||
body << formatv(optionalTypeParserCode, listName);
|
||||
} else {
|
||||
const char *parserCode =
|
||||
dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
|
||||
@ -1903,14 +1904,14 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
|
||||
if (!operand.isVariadicOfVariadic())
|
||||
continue;
|
||||
if (op.getDialect().usePropertiesForAttributes()) {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" result.getOrAddProperties<{0}::Properties>().{1} = "
|
||||
"parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
|
||||
op.getCppClassName(),
|
||||
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
|
||||
operand.name);
|
||||
} else {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" result.addAttribute(\"{0}\", "
|
||||
"parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
|
||||
"\n",
|
||||
@ -2160,7 +2161,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
|
||||
if (var->isVariadic())
|
||||
body << name << "().getTypes()";
|
||||
else if (var->isOptional())
|
||||
body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
|
||||
body << formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
|
||||
else
|
||||
body << name << "().getType()";
|
||||
|
||||
@ -2195,8 +2196,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
|
||||
static void genRegionPrinter(const Twine ®ionName, MethodBody &body,
|
||||
bool hasImplicitTermTrait) {
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
|
||||
regionName);
|
||||
body << formatv(regionSingleBlockImplicitTerminatorPrinterCode, regionName);
|
||||
else
|
||||
body << " _odsPrinter.printRegion(" << regionName << ");\n";
|
||||
}
|
||||
@ -2220,12 +2220,12 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
|
||||
auto *operand = dyn_cast<OperandVariable>(arg);
|
||||
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
|
||||
if (var->isVariadicOfVariadic())
|
||||
return body << llvm::formatv("{0}().join().getTypes()",
|
||||
op.getGetterName(var->name));
|
||||
return body << formatv("{0}().join().getTypes()",
|
||||
op.getGetterName(var->name));
|
||||
if (var->isVariadic())
|
||||
return body << op.getGetterName(var->name) << "().getTypes()";
|
||||
if (var->isOptional())
|
||||
return body << llvm::formatv(
|
||||
return body << formatv(
|
||||
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
|
||||
"::llvm::ArrayRef<::mlir::Type>())",
|
||||
op.getGetterName(var->name));
|
||||
@ -2242,10 +2242,10 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
||||
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
|
||||
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
|
||||
|
||||
body << llvm::formatv(enumAttrBeginPrinterCode,
|
||||
(var->attr.isOptional() ? "*" : "") +
|
||||
op.getGetterName(var->name),
|
||||
enumAttr.getSymbolToStringFnName());
|
||||
body << formatv(enumAttrBeginPrinterCode,
|
||||
(var->attr.isOptional() ? "*" : "") +
|
||||
op.getGetterName(var->name),
|
||||
enumAttr.getSymbolToStringFnName());
|
||||
|
||||
// Get a string containing all of the cases that can't be represented with a
|
||||
// keyword.
|
||||
@ -2276,9 +2276,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
||||
if (nonKeywordCases.test(it.index()))
|
||||
continue;
|
||||
StringRef symbol = it.value().getSymbol();
|
||||
body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
|
||||
llvm::isDigit(symbol.front()) ? ("_" + symbol)
|
||||
: symbol);
|
||||
body << formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
|
||||
llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol);
|
||||
}
|
||||
body << " _odsPrinter << caseValueStr;\n"
|
||||
" break;\n"
|
||||
@ -2584,7 +2583,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
|
||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||
if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
|
||||
if (operand->getVar()->isVariadicOfVariadic()) {
|
||||
body << llvm::formatv(
|
||||
body << formatv(
|
||||
" ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
|
||||
"[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
|
||||
"types << \")\"; });\n",
|
||||
@ -2710,7 +2709,7 @@ private:
|
||||
/// Verify the state of operation operands within the format.
|
||||
LogicalResult
|
||||
verifyOperands(SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
|
||||
/// Verify the state of operation regions within the format.
|
||||
LogicalResult verifyRegions(SMLoc loc);
|
||||
@ -2718,7 +2717,7 @@ private:
|
||||
/// Verify the state of operation results within the format.
|
||||
LogicalResult
|
||||
verifyResults(SMLoc loc,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
|
||||
/// Verify the state of operation successors within the format.
|
||||
LogicalResult verifySuccessors(SMLoc loc);
|
||||
@ -2730,18 +2729,17 @@ private:
|
||||
/// resolution.
|
||||
void handleAllTypesMatchConstraint(
|
||||
ArrayRef<StringRef> values,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||
/// Check for inferable type resolution given all operands, and or results,
|
||||
/// have the same type. If 'includeResults' is true, the results also have the
|
||||
/// same type as all of the operands.
|
||||
void handleSameTypesConstraint(
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
bool includeResults);
|
||||
/// Check for inferable type resolution based on another operand, result, or
|
||||
/// attribute.
|
||||
void handleTypesMatchConstraint(
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
const llvm::Record &def);
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
|
||||
|
||||
/// Returns an argument or attribute with the given name that has been seen
|
||||
/// within the format.
|
||||
@ -2794,9 +2792,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
|
||||
"custom assembly format");
|
||||
|
||||
// Check for any type traits that we can use for inferring types.
|
||||
llvm::StringMap<TypeResolutionInstance> variableTyResolver;
|
||||
StringMap<TypeResolutionInstance> variableTyResolver;
|
||||
for (const Trait &trait : op.getTraits()) {
|
||||
const llvm::Record &def = trait.getDef();
|
||||
const Record &def = trait.getDef();
|
||||
if (def.isSubClassOf("AllTypesMatch")) {
|
||||
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
|
||||
variableTyResolver);
|
||||
@ -2995,10 +2993,9 @@ OpFormatParser::verifyAttributeColonType(SMLoc loc,
|
||||
return false;
|
||||
// If we encounter `:`, the range is known to be invalid.
|
||||
(void)emitError(
|
||||
loc,
|
||||
llvm::formatv("format ambiguity caused by `:` literal found after "
|
||||
"attribute `{0}` which does not have a buildable type",
|
||||
cast<AttributeVariable>(base)->getVar()->name));
|
||||
loc, formatv("format ambiguity caused by `:` literal found after "
|
||||
"attribute `{0}` which does not have a buildable type",
|
||||
cast<AttributeVariable>(base)->getVar()->name));
|
||||
return true;
|
||||
};
|
||||
return verifyAdjacentElements(isBase, isInvalid, elements);
|
||||
@ -3018,9 +3015,9 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
|
||||
return false;
|
||||
(void)emitErrorAndNote(
|
||||
loc,
|
||||
llvm::formatv("format ambiguity caused by `attr-dict` directive "
|
||||
"followed by region `{0}`",
|
||||
region->getVar()->name),
|
||||
formatv("format ambiguity caused by `attr-dict` directive "
|
||||
"followed by region `{0}`",
|
||||
region->getVar()->name),
|
||||
"try using `attr-dict-with-keyword` instead");
|
||||
return true;
|
||||
};
|
||||
@ -3028,7 +3025,7 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
|
||||
}
|
||||
|
||||
LogicalResult OpFormatParser::verifyOperands(
|
||||
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
// Check that all of the operands are within the format, and their types can
|
||||
// be inferred.
|
||||
auto &buildableTypes = fmt.buildableTypes;
|
||||
@ -3093,7 +3090,7 @@ LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
|
||||
}
|
||||
|
||||
LogicalResult OpFormatParser::verifyResults(
|
||||
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
// If we format all of the types together, there is nothing to check.
|
||||
if (fmt.allResultTypes)
|
||||
return success();
|
||||
@ -3197,7 +3194,7 @@ OpFormatParser::verifyOIListElements(SMLoc loc,
|
||||
|
||||
void OpFormatParser::handleAllTypesMatchConstraint(
|
||||
ArrayRef<StringRef> values,
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||
// Check to see if this value matches a resolved operand or result type.
|
||||
ConstArgument arg = findSeenArg(values[i]);
|
||||
@ -3213,7 +3210,7 @@ void OpFormatParser::handleAllTypesMatchConstraint(
|
||||
}
|
||||
|
||||
void OpFormatParser::handleSameTypesConstraint(
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
bool includeResults) {
|
||||
const NamedTypeConstraint *resolver = nullptr;
|
||||
int resolvedIt = -1;
|
||||
@ -3238,8 +3235,7 @@ void OpFormatParser::handleSameTypesConstraint(
|
||||
}
|
||||
|
||||
void OpFormatParser::handleTypesMatchConstraint(
|
||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||
const llvm::Record &def) {
|
||||
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
|
||||
StringRef lhsName = def.getValueAsString("lhs");
|
||||
StringRef rhsName = def.getValueAsString("rhs");
|
||||
StringRef transformer = def.getValueAsString("transformer");
|
||||
|
@ -41,7 +41,7 @@ static std::string getOperationName(const Record &def) {
|
||||
auto opName = def.getValueAsString("opName");
|
||||
if (prefix.empty())
|
||||
return std::string(opName);
|
||||
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
|
||||
return std::string(formatv("{0}.{1}", prefix, opName));
|
||||
}
|
||||
|
||||
std::vector<const Record *>
|
||||
@ -50,7 +50,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
|
||||
if (!classDef)
|
||||
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
|
||||
|
||||
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
|
||||
Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
|
||||
std::vector<const Record *> defs;
|
||||
for (const auto &def : recordKeeper.getDefs()) {
|
||||
if (!def.second->isSubClassOf(classDef))
|
||||
@ -70,7 +70,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
|
||||
}
|
||||
|
||||
bool mlir::tblgen::isPythonReserved(StringRef str) {
|
||||
static llvm::StringSet<> reserved({
|
||||
static StringSet<> reserved({
|
||||
"False", "None", "True", "and", "as", "assert", "async",
|
||||
"await", "break", "class", "continue", "def", "del", "elif",
|
||||
"else", "except", "finally", "for", "from", "global", "if",
|
||||
@ -86,8 +86,8 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
|
||||
}
|
||||
|
||||
void mlir::tblgen::shardOpDefinitions(
|
||||
ArrayRef<const llvm::Record *> defs,
|
||||
SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs) {
|
||||
ArrayRef<const Record *> defs,
|
||||
SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
|
||||
assert(opShardCount > 0 && "expected a positive shard count");
|
||||
if (opShardCount == 1) {
|
||||
shardedDefs.push_back(defs);
|
||||
|
@ -23,6 +23,8 @@
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
using namespace mlir;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
using mlir::tblgen::Interface;
|
||||
using mlir::tblgen::InterfaceMethod;
|
||||
using mlir::tblgen::OpInterface;
|
||||
@ -61,14 +63,13 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
|
||||
|
||||
/// Get an array of all OpInterface definitions but exclude those subclassing
|
||||
/// "DeclareOpInterfaceMethods".
|
||||
static std::vector<const llvm::Record *>
|
||||
getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
|
||||
StringRef name) {
|
||||
std::vector<const llvm::Record *> defs =
|
||||
static std::vector<const Record *>
|
||||
getAllInterfaceDefinitions(const RecordKeeper &recordKeeper, StringRef name) {
|
||||
std::vector<const Record *> defs =
|
||||
recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
|
||||
|
||||
std::string declareName = ("Declare" + name + "InterfaceMethods").str();
|
||||
llvm::erase_if(defs, [&](const llvm::Record *def) {
|
||||
llvm::erase_if(defs, [&](const Record *def) {
|
||||
// Ignore any "declare methods" interfaces.
|
||||
if (def->isSubClassOf(declareName))
|
||||
return true;
|
||||
@ -88,7 +89,7 @@ public:
|
||||
bool emitInterfaceDocs();
|
||||
|
||||
protected:
|
||||
InterfaceGenerator(std::vector<const llvm::Record *> &&defs, raw_ostream &os)
|
||||
InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
|
||||
: defs(std::move(defs)), os(os) {}
|
||||
|
||||
void emitConceptDecl(const Interface &interface);
|
||||
@ -99,7 +100,7 @@ protected:
|
||||
void emitInterfaceDecl(const Interface &interface);
|
||||
|
||||
/// The set of interface records to emit.
|
||||
std::vector<const llvm::Record *> defs;
|
||||
std::vector<const Record *> defs;
|
||||
// The stream to emit to.
|
||||
raw_ostream &os;
|
||||
/// The C++ value type of the interface, e.g. Operation*.
|
||||
@ -118,7 +119,7 @@ protected:
|
||||
|
||||
/// A specialized generator for attribute interfaces.
|
||||
struct AttrInterfaceGenerator : public InterfaceGenerator {
|
||||
AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
||||
AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
|
||||
valueType = "::mlir::Attribute";
|
||||
interfaceBaseType = "AttributeInterface";
|
||||
@ -133,7 +134,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
|
||||
};
|
||||
/// A specialized generator for operation interfaces.
|
||||
struct OpInterfaceGenerator : public InterfaceGenerator {
|
||||
OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
||||
OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
|
||||
valueType = "::mlir::Operation *";
|
||||
interfaceBaseType = "OpInterface";
|
||||
@ -149,7 +150,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
|
||||
};
|
||||
/// A specialized generator for type interfaces.
|
||||
struct TypeInterfaceGenerator : public InterfaceGenerator {
|
||||
TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
||||
TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
|
||||
valueType = "::mlir::Type";
|
||||
interfaceBaseType = "TypeInterface";
|
||||
@ -607,13 +608,13 @@ bool InterfaceGenerator::emitInterfaceDecls() {
|
||||
llvm::emitSourceFileHeader("Interface Declarations", os);
|
||||
// Sort according to ID, so defs are emitted in the order in which they appear
|
||||
// in the Tablegen file.
|
||||
std::vector<const llvm::Record *> sortedDefs(defs);
|
||||
llvm::sort(sortedDefs, [](const llvm::Record *lhs, const llvm::Record *rhs) {
|
||||
std::vector<const Record *> sortedDefs(defs);
|
||||
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
|
||||
return lhs->getID() < rhs->getID();
|
||||
});
|
||||
for (const llvm::Record *def : sortedDefs)
|
||||
for (const Record *def : sortedDefs)
|
||||
emitInterfaceDecl(Interface(def));
|
||||
for (const llvm::Record *def : sortedDefs)
|
||||
for (const Record *def : sortedDefs)
|
||||
emitModelMethodsDef(Interface(def));
|
||||
return false;
|
||||
}
|
||||
@ -622,8 +623,7 @@ bool InterfaceGenerator::emitInterfaceDecls() {
|
||||
// GEN: Interface documentation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitInterfaceDoc(const llvm::Record &interfaceDef,
|
||||
raw_ostream &os) {
|
||||
static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
|
||||
Interface interface(&interfaceDef);
|
||||
|
||||
// Emit the interface name followed by the description.
|
||||
@ -684,15 +684,15 @@ struct InterfaceGenRegistration {
|
||||
genDefDesc(("Generate " + genDesc + " interface definitions").str()),
|
||||
genDocDesc(("Generate " + genDesc + " interface documentation").str()),
|
||||
genDecls(genDeclArg, genDeclDesc,
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return GeneratorT(records, os).emitInterfaceDecls();
|
||||
}),
|
||||
genDefs(genDefArg, genDefDesc,
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return GeneratorT(records, os).emitInterfaceDefs();
|
||||
}),
|
||||
genDocs(genDocArg, genDocDesc,
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return GeneratorT(records, os).emitInterfaceDocs();
|
||||
}) {}
|
||||
|
||||
|
@ -23,6 +23,9 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::formatv;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
/// File header and includes.
|
||||
/// {0} is the dialect namespace.
|
||||
@ -315,9 +318,9 @@ static std::string sanitizeName(StringRef name) {
|
||||
}
|
||||
|
||||
static std::string attrSizedTraitForKind(const char *kind) {
|
||||
return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
|
||||
llvm::StringRef(kind).take_front().upper(),
|
||||
llvm::StringRef(kind).drop_front());
|
||||
return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
|
||||
StringRef(kind).take_front().upper(),
|
||||
StringRef(kind).drop_front());
|
||||
}
|
||||
|
||||
/// Emits accessors to "elements" of an Op definition. Currently, the supported
|
||||
@ -328,15 +331,14 @@ static void emitElementAccessors(
|
||||
unsigned numVariadicGroups, unsigned numElements,
|
||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||
getElement) {
|
||||
assert(llvm::is_contained(
|
||||
llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
|
||||
assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
|
||||
kind) &&
|
||||
"unsupported kind");
|
||||
|
||||
// Traits indicating how to process variadic elements.
|
||||
std::string sameSizeTrait =
|
||||
llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
|
||||
llvm::StringRef(kind).take_front().upper(),
|
||||
llvm::StringRef(kind).drop_front());
|
||||
std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
|
||||
StringRef(kind).take_front().upper(),
|
||||
StringRef(kind).drop_front());
|
||||
std::string attrSizedTrait = attrSizedTraitForKind(kind);
|
||||
|
||||
// If there is only one variable-length element group, its size can be
|
||||
@ -351,15 +353,14 @@ static void emitElementAccessors(
|
||||
if (element.name.empty())
|
||||
continue;
|
||||
if (element.isVariableLength()) {
|
||||
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
|
||||
: opOneVariadicTemplate,
|
||||
sanitizeName(element.name), kind, numElements, i);
|
||||
os << formatv(element.isOptional() ? opOneOptionalTemplate
|
||||
: opOneVariadicTemplate,
|
||||
sanitizeName(element.name), kind, numElements, i);
|
||||
} else if (seenVariableLength) {
|
||||
os << llvm::formatv(opSingleAfterVariableTemplate,
|
||||
sanitizeName(element.name), kind, numElements, i);
|
||||
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
|
||||
kind, numElements, i);
|
||||
} else {
|
||||
os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
|
||||
i);
|
||||
os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
|
||||
}
|
||||
}
|
||||
return;
|
||||
@ -382,14 +383,13 @@ static void emitElementAccessors(
|
||||
for (unsigned i = 0; i < numElements; ++i) {
|
||||
const NamedTypeConstraint &element = getElement(op, i);
|
||||
if (!element.name.empty()) {
|
||||
os << llvm::formatv(opVariadicEqualPrefixTemplate,
|
||||
sanitizeName(element.name), kind, numSimpleLength,
|
||||
numVariadicGroups, numPrecedingSimple,
|
||||
numPrecedingVariadic);
|
||||
os << llvm::formatv(element.isVariableLength()
|
||||
? opVariadicEqualVariadicTemplate
|
||||
: opVariadicEqualSimpleTemplate,
|
||||
kind);
|
||||
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
|
||||
kind, numSimpleLength, numVariadicGroups,
|
||||
numPrecedingSimple, numPrecedingVariadic);
|
||||
os << formatv(element.isVariableLength()
|
||||
? opVariadicEqualVariadicTemplate
|
||||
: opVariadicEqualSimpleTemplate,
|
||||
kind);
|
||||
}
|
||||
if (element.isVariableLength())
|
||||
++numPrecedingVariadic;
|
||||
@ -412,9 +412,9 @@ static void emitElementAccessors(
|
||||
trailing = "[0]";
|
||||
else if (element.isOptional())
|
||||
trailing = std::string(
|
||||
llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
|
||||
os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
|
||||
kind, i, trailing);
|
||||
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
|
||||
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
|
||||
i, trailing);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -459,27 +459,21 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
|
||||
|
||||
// Unit attributes are handled specially.
|
||||
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
||||
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
|
||||
os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
|
||||
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (namedAttr.attr.isOptional()) {
|
||||
os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << formatv(optionalAttributeGetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << formatv(optionalAttributeSetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
|
||||
} else {
|
||||
os << llvm::formatv(attributeGetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << llvm::formatv(attributeSetterTemplate, sanitizedName,
|
||||
namedAttr.name);
|
||||
os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
|
||||
os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
|
||||
// Non-optional attributes cannot be deleted.
|
||||
}
|
||||
}
|
||||
@ -595,7 +589,7 @@ static bool canInferType(const Operator &op) {
|
||||
/// accept them as arguments.
|
||||
static void
|
||||
populateBuilderArgsResults(const Operator &op,
|
||||
llvm::SmallVectorImpl<std::string> &builderArgs) {
|
||||
SmallVectorImpl<std::string> &builderArgs) {
|
||||
if (canInferType(op))
|
||||
return;
|
||||
|
||||
@ -607,7 +601,7 @@ populateBuilderArgsResults(const Operator &op,
|
||||
// to properly match the built-in result accessor.
|
||||
name = "result";
|
||||
} else {
|
||||
name = llvm::formatv("_gen_res_{0}", i);
|
||||
name = formatv("_gen_res_{0}", i);
|
||||
}
|
||||
}
|
||||
name = sanitizeName(name);
|
||||
@ -620,14 +614,13 @@ populateBuilderArgsResults(const Operator &op,
|
||||
/// appear in the `arguments` field of the op definition. Additionally,
|
||||
/// `operandNames` is populated with names of operands in their order of
|
||||
/// appearance.
|
||||
static void
|
||||
populateBuilderArgs(const Operator &op,
|
||||
llvm::SmallVectorImpl<std::string> &builderArgs,
|
||||
llvm::SmallVectorImpl<std::string> &operandNames) {
|
||||
static void populateBuilderArgs(const Operator &op,
|
||||
SmallVectorImpl<std::string> &builderArgs,
|
||||
SmallVectorImpl<std::string> &operandNames) {
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
std::string name = op.getArgName(i).str();
|
||||
if (name.empty())
|
||||
name = llvm::formatv("_gen_arg_{0}", i);
|
||||
name = formatv("_gen_arg_{0}", i);
|
||||
name = sanitizeName(name);
|
||||
builderArgs.push_back(name);
|
||||
if (!op.getArg(i).is<NamedAttribute *>())
|
||||
@ -637,15 +630,16 @@ populateBuilderArgs(const Operator &op,
|
||||
|
||||
/// Populates `builderArgs` with the Python-compatible names of builder function
|
||||
/// successor arguments. Additionally, `successorArgNames` is also populated.
|
||||
static void populateBuilderArgsSuccessors(
|
||||
const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
|
||||
llvm::SmallVectorImpl<std::string> &successorArgNames) {
|
||||
static void
|
||||
populateBuilderArgsSuccessors(const Operator &op,
|
||||
SmallVectorImpl<std::string> &builderArgs,
|
||||
SmallVectorImpl<std::string> &successorArgNames) {
|
||||
|
||||
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
|
||||
NamedSuccessor successor = op.getSuccessor(i);
|
||||
std::string name = std::string(successor.name);
|
||||
if (name.empty())
|
||||
name = llvm::formatv("_gen_successor_{0}", i);
|
||||
name = formatv("_gen_successor_{0}", i);
|
||||
name = sanitizeName(name);
|
||||
builderArgs.push_back(name);
|
||||
successorArgNames.push_back(name);
|
||||
@ -658,9 +652,8 @@ static void populateBuilderArgsSuccessors(
|
||||
/// operands and attributes in the same order as they appear in the `arguments`
|
||||
/// field.
|
||||
static void
|
||||
populateBuilderLinesAttr(const Operator &op,
|
||||
llvm::ArrayRef<std::string> argNames,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
|
||||
SmallVectorImpl<std::string> &builderLines) {
|
||||
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
Argument arg = op.getArg(i);
|
||||
@ -670,12 +663,12 @@ populateBuilderLinesAttr(const Operator &op,
|
||||
|
||||
// Unit attributes are handled specially.
|
||||
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
||||
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
|
||||
attribute->name, argNames[i]));
|
||||
builderLines.push_back(
|
||||
formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
|
||||
continue;
|
||||
}
|
||||
|
||||
builderLines.push_back(llvm::formatv(
|
||||
builderLines.push_back(formatv(
|
||||
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
|
||||
? initOptionalAttributeWithBuilderTemplate
|
||||
: initAttributeWithBuilderTemplate,
|
||||
@ -686,30 +679,30 @@ populateBuilderLinesAttr(const Operator &op,
|
||||
/// Populates `builderLines` with additional lines that are required in the
|
||||
/// builder to set up successors. successorArgNames is expected to correspond
|
||||
/// to the Python argument name for each successor on the op.
|
||||
static void populateBuilderLinesSuccessors(
|
||||
const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
static void
|
||||
populateBuilderLinesSuccessors(const Operator &op,
|
||||
ArrayRef<std::string> successorArgNames,
|
||||
SmallVectorImpl<std::string> &builderLines) {
|
||||
if (successorArgNames.empty()) {
|
||||
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
|
||||
builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
|
||||
return;
|
||||
}
|
||||
|
||||
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
|
||||
builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
|
||||
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
|
||||
auto &argName = successorArgNames[i];
|
||||
const NamedSuccessor &successor = op.getSuccessor(i);
|
||||
builderLines.push_back(
|
||||
llvm::formatv(addSuccessorTemplate,
|
||||
successor.isVariadic() ? "extend" : "append", argName));
|
||||
builderLines.push_back(formatv(addSuccessorTemplate,
|
||||
successor.isVariadic() ? "extend" : "append",
|
||||
argName));
|
||||
}
|
||||
}
|
||||
|
||||
/// Populates `builderLines` with additional lines that are required in the
|
||||
/// builder to set up op operands.
|
||||
static void
|
||||
populateBuilderLinesOperand(const Operator &op,
|
||||
llvm::ArrayRef<std::string> names,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
|
||||
SmallVectorImpl<std::string> &builderLines) {
|
||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
|
||||
|
||||
// For each element, find or generate a name.
|
||||
@ -718,7 +711,7 @@ populateBuilderLinesOperand(const Operator &op,
|
||||
std::string name = names[i];
|
||||
|
||||
// Choose the formatting string based on the element kind.
|
||||
llvm::StringRef formatString;
|
||||
StringRef formatString;
|
||||
if (!element.isVariableLength()) {
|
||||
formatString = singleOperandAppendTemplate;
|
||||
} else if (element.isOptional()) {
|
||||
@ -738,7 +731,7 @@ populateBuilderLinesOperand(const Operator &op,
|
||||
}
|
||||
}
|
||||
|
||||
builderLines.push_back(llvm::formatv(formatString.data(), name));
|
||||
builderLines.push_back(formatv(formatString.data(), name));
|
||||
}
|
||||
}
|
||||
|
||||
@ -758,7 +751,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
|
||||
/// Appends the given multiline string as individual strings into
|
||||
/// `builderLines`.
|
||||
static void appendLineByLine(StringRef string,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
SmallVectorImpl<std::string> &builderLines) {
|
||||
|
||||
std::pair<StringRef, StringRef> split = std::make_pair(string, string);
|
||||
do {
|
||||
@ -770,14 +763,13 @@ static void appendLineByLine(StringRef string,
|
||||
/// Populates `builderLines` with additional lines that are required in the
|
||||
/// builder to set up op results.
|
||||
static void
|
||||
populateBuilderLinesResult(const Operator &op,
|
||||
llvm::ArrayRef<std::string> names,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
|
||||
SmallVectorImpl<std::string> &builderLines) {
|
||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
|
||||
|
||||
if (hasSameArgumentAndResultTypes(op)) {
|
||||
builderLines.push_back(llvm::formatv(
|
||||
appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
|
||||
builderLines.push_back(formatv(appendSameResultsTemplate,
|
||||
"operands[0].type", op.getNumResults()));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -785,12 +777,11 @@ populateBuilderLinesResult(const Operator &op,
|
||||
const NamedAttribute &firstAttr = op.getAttribute(0);
|
||||
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
|
||||
"from which the type is derived");
|
||||
appendLineByLine(
|
||||
llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
|
||||
builderLines);
|
||||
builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
|
||||
"_ods_derived_result_type",
|
||||
op.getNumResults()));
|
||||
appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
|
||||
builderLines);
|
||||
builderLines.push_back(formatv(appendSameResultsTemplate,
|
||||
"_ods_derived_result_type",
|
||||
op.getNumResults()));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -803,7 +794,7 @@ populateBuilderLinesResult(const Operator &op,
|
||||
std::string name = names[i];
|
||||
|
||||
// Choose the formatting string based on the element kind.
|
||||
llvm::StringRef formatString;
|
||||
StringRef formatString;
|
||||
if (!element.isVariableLength()) {
|
||||
formatString = singleResultAppendTemplate;
|
||||
} else if (element.isOptional()) {
|
||||
@ -819,17 +810,16 @@ populateBuilderLinesResult(const Operator &op,
|
||||
}
|
||||
}
|
||||
|
||||
builderLines.push_back(llvm::formatv(formatString.data(), name));
|
||||
builderLines.push_back(formatv(formatString.data(), name));
|
||||
}
|
||||
}
|
||||
|
||||
/// If the operation has variadic regions, adds a builder argument to specify
|
||||
/// the number of those regions and builder lines to forward it to the generic
|
||||
/// constructor.
|
||||
static void
|
||||
populateBuilderRegions(const Operator &op,
|
||||
llvm::SmallVectorImpl<std::string> &builderArgs,
|
||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
||||
static void populateBuilderRegions(const Operator &op,
|
||||
SmallVectorImpl<std::string> &builderArgs,
|
||||
SmallVectorImpl<std::string> &builderLines) {
|
||||
if (op.hasNoVariadicRegions())
|
||||
return;
|
||||
|
||||
@ -844,19 +834,19 @@ populateBuilderRegions(const Operator &op,
|
||||
.str();
|
||||
builderArgs.push_back(name);
|
||||
builderLines.push_back(
|
||||
llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
|
||||
formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
|
||||
}
|
||||
|
||||
/// Emits a default builder constructing an operation from the list of its
|
||||
/// result types, followed by a list of its operands. Returns vector
|
||||
/// of fully built functionArgs for downstream users (to save having to
|
||||
/// rebuild anew).
|
||||
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
||||
raw_ostream &os) {
|
||||
llvm::SmallVector<std::string> builderArgs;
|
||||
llvm::SmallVector<std::string> builderLines;
|
||||
llvm::SmallVector<std::string> operandArgNames;
|
||||
llvm::SmallVector<std::string> successorArgNames;
|
||||
static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
||||
raw_ostream &os) {
|
||||
SmallVector<std::string> builderArgs;
|
||||
SmallVector<std::string> builderLines;
|
||||
SmallVector<std::string> operandArgNames;
|
||||
SmallVector<std::string> successorArgNames;
|
||||
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
|
||||
op.getNumNativeAttributes() + op.getNumSuccessors());
|
||||
populateBuilderArgsResults(op, builderArgs);
|
||||
@ -866,10 +856,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
||||
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
|
||||
|
||||
populateBuilderLinesOperand(op, operandArgNames, builderLines);
|
||||
populateBuilderLinesAttr(
|
||||
op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines);
|
||||
populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
|
||||
builderLines);
|
||||
populateBuilderLinesResult(
|
||||
op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
|
||||
op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
|
||||
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
|
||||
populateBuilderRegions(op, builderArgs, builderLines);
|
||||
|
||||
@ -896,7 +886,7 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
||||
};
|
||||
|
||||
// StringRefs in functionArgs refer to strings allocated by builderArgs.
|
||||
llvm::SmallVector<llvm::StringRef> functionArgs;
|
||||
SmallVector<StringRef> functionArgs;
|
||||
|
||||
// Add positional arguments.
|
||||
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
|
||||
@ -929,11 +919,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
||||
initArgs.push_back("loc=loc");
|
||||
initArgs.push_back("ip=ip");
|
||||
|
||||
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
|
||||
llvm::join(builderLines, "\n "),
|
||||
llvm::join(initArgs, ", "));
|
||||
os << formatv(initTemplate, llvm::join(functionArgs, ", "),
|
||||
llvm::join(builderLines, "\n "), llvm::join(initArgs, ", "));
|
||||
return llvm::to_vector<8>(
|
||||
llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
|
||||
llvm::map_range(functionArgs, [](StringRef s) { return s.str(); }));
|
||||
}
|
||||
|
||||
static void emitSegmentSpec(
|
||||
@ -955,15 +944,15 @@ static void emitSegmentSpec(
|
||||
}
|
||||
segmentSpec.append("]");
|
||||
|
||||
os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
|
||||
os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
|
||||
}
|
||||
|
||||
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
|
||||
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
|
||||
// Note that the base OpView class defines this as (0, True).
|
||||
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
|
||||
os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
|
||||
op.hasNoVariadicRegions() ? "True" : "False");
|
||||
os << formatv(opClassRegionSpecTemplate, minRegionCount,
|
||||
op.hasNoVariadicRegions() ? "True" : "False");
|
||||
}
|
||||
|
||||
/// Emits named accessors to regions.
|
||||
@ -975,20 +964,20 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
|
||||
|
||||
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
|
||||
"expected only the last region to be variadic");
|
||||
os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
|
||||
std::to_string(en.index()) +
|
||||
(region.isVariadic() ? ":" : ""));
|
||||
os << formatv(regionAccessorTemplate, sanitizeName(region.name),
|
||||
std::to_string(en.index()) +
|
||||
(region.isVariadic() ? ":" : ""));
|
||||
}
|
||||
}
|
||||
|
||||
/// Emits builder that extracts results from op
|
||||
static void emitValueBuilder(const Operator &op,
|
||||
llvm::SmallVector<std::string> functionArgs,
|
||||
SmallVector<std::string> functionArgs,
|
||||
raw_ostream &os) {
|
||||
// Params with (possibly) default args.
|
||||
auto valueBuilderParams =
|
||||
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
|
||||
llvm::SmallVector<llvm::StringRef> argMaybeDefault =
|
||||
SmallVector<StringRef> argMaybeDefault =
|
||||
llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
|
||||
auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
|
||||
if (argMaybeDefault.size() == 2)
|
||||
@ -1005,7 +994,7 @@ static void emitValueBuilder(const Operator &op,
|
||||
});
|
||||
std::string nameWithoutDialect =
|
||||
op.getOperationName().substr(op.getOperationName().find('.') + 1);
|
||||
os << llvm::formatv(
|
||||
os << formatv(
|
||||
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
|
||||
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
|
||||
llvm::join(opBuilderArgs, ", "),
|
||||
@ -1016,8 +1005,7 @@ static void emitValueBuilder(const Operator &op,
|
||||
|
||||
/// Emits bindings for a specific Op to the given output stream.
|
||||
static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
||||
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
|
||||
op.getOperationName());
|
||||
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());
|
||||
|
||||
// Sized segments.
|
||||
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
|
||||
@ -1028,7 +1016,7 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
||||
}
|
||||
|
||||
emitRegionAttributes(op, os);
|
||||
llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
|
||||
SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
|
||||
emitOperandAccessors(op, os);
|
||||
emitAttributeAccessors(op, os);
|
||||
emitResultAccessors(op, os);
|
||||
@ -1039,17 +1027,17 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
||||
/// Emits bindings for the dialect specified in the command line, including file
|
||||
/// headers and utilities. Returns `false` on success to comply with Tablegen
|
||||
/// registration requirements.
|
||||
static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
|
||||
if (clDialectName.empty())
|
||||
llvm::PrintFatalError("dialect name not provided");
|
||||
|
||||
os << fileHeader;
|
||||
if (!clDialectExtensionName.empty())
|
||||
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
|
||||
os << formatv(dialectExtensionTemplate, clDialectName.getValue());
|
||||
else
|
||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
os << formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
|
||||
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||
Operator op(rec);
|
||||
if (op.getDialectName() == clDialectName.getValue())
|
||||
emitOpBindings(op, os);
|
||||
|
@ -20,6 +20,8 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::formatv;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
static llvm::cl::OptionCategory
|
||||
passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl");
|
||||
@ -56,7 +58,7 @@ const char *const fileFooter = R"(
|
||||
)";
|
||||
|
||||
/// Emit TODO
|
||||
static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
static bool emitCAPIHeader(const RecordKeeper &records, raw_ostream &os) {
|
||||
os << fileHeader;
|
||||
os << "// Registration for the entire group\n";
|
||||
os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName
|
||||
@ -64,7 +66,7 @@ static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
|
||||
Pass pass(def);
|
||||
StringRef defName = pass.getDef()->getName();
|
||||
os << llvm::formatv(passDecl, groupName, defName);
|
||||
os << formatv(passDecl, groupName, defName);
|
||||
}
|
||||
os << fileFooter;
|
||||
return false;
|
||||
@ -91,9 +93,9 @@ void mlirRegister{0}Passes(void) {{
|
||||
}
|
||||
)";
|
||||
|
||||
static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
static bool emitCAPIImpl(const RecordKeeper &records, raw_ostream &os) {
|
||||
os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
|
||||
os << llvm::formatv(passGroupRegistrationCode, groupName);
|
||||
os << formatv(passGroupRegistrationCode, groupName);
|
||||
|
||||
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
|
||||
Pass pass(def);
|
||||
@ -103,10 +105,9 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
||||
constructorCall = constructor.str();
|
||||
else
|
||||
constructorCall =
|
||||
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
|
||||
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
|
||||
|
||||
os << llvm::formatv(passCreateDef, groupName, defName, constructorCall);
|
||||
os << formatv(passCreateDef, groupName, defName, constructorCall);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
/// Emit the documentation for the given pass.
|
||||
static void emitDoc(const Pass &pass, raw_ostream &os) {
|
||||
@ -56,7 +57,7 @@ static void emitDoc(const Pass &pass, raw_ostream &os) {
|
||||
}
|
||||
}
|
||||
|
||||
static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
static void emitDocs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||
auto passDefs = recordKeeper.getAllDerivedDefinitions("PassBase");
|
||||
|
||||
@ -74,7 +75,7 @@ static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
|
||||
static mlir::GenRegistration
|
||||
genRegister("gen-pass-doc", "Generate pass documentation",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
emitDocs(records, os);
|
||||
return false;
|
||||
});
|
||||
|
@ -21,6 +21,8 @@
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
using llvm::formatv;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
|
||||
static llvm::cl::opt<std::string>
|
||||
@ -28,7 +30,7 @@ static llvm::cl::opt<std::string>
|
||||
llvm::cl::cat(passGenCat));
|
||||
|
||||
/// Extract the list of passes from the TableGen records.
|
||||
static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
|
||||
static std::vector<Pass> getPasses(const RecordKeeper &recordKeeper) {
|
||||
std::vector<Pass> passes;
|
||||
|
||||
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
|
||||
@ -91,7 +93,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
|
||||
if (options.empty())
|
||||
return;
|
||||
|
||||
os << llvm::formatv("struct {0}Options {{\n", passName);
|
||||
os << formatv("struct {0}Options {{\n", passName);
|
||||
|
||||
for (const PassOption &opt : options) {
|
||||
std::string type = opt.getType().str();
|
||||
@ -99,7 +101,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
|
||||
if (opt.isListOption())
|
||||
type = "::llvm::SmallVector<" + type + ">";
|
||||
|
||||
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
|
||||
os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
|
||||
|
||||
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
|
||||
os << " = " << defaultVal;
|
||||
@ -128,9 +130,9 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
|
||||
|
||||
// Declaration of the constructor with options.
|
||||
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
|
||||
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}("
|
||||
"{0}Options options);\n",
|
||||
passName);
|
||||
os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
|
||||
"{0}Options options);\n",
|
||||
passName);
|
||||
}
|
||||
|
||||
os << "#undef " << enableVarName << "\n";
|
||||
@ -147,14 +149,13 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
|
||||
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
||||
constructorCall = constructor.str();
|
||||
else
|
||||
constructorCall =
|
||||
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
|
||||
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
|
||||
|
||||
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
|
||||
constructorCall);
|
||||
os << formatv(passRegistrationCode, pass.getDef()->getName(),
|
||||
constructorCall);
|
||||
}
|
||||
|
||||
os << llvm::formatv(passGroupRegistrationCode, groupName);
|
||||
os << formatv(passGroupRegistrationCode, groupName);
|
||||
|
||||
for (const Pass &pass : passes)
|
||||
os << " register" << pass.getDef()->getName() << "();\n";
|
||||
@ -270,9 +271,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
||||
os.indent(2) << "::mlir::Pass::"
|
||||
<< (opt.isListOption() ? "ListOption" : "Option");
|
||||
|
||||
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
|
||||
opt.getType(), opt.getCppVariableName(),
|
||||
opt.getArgument(), opt.getDescription());
|
||||
os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
|
||||
opt.getType(), opt.getCppVariableName(), opt.getArgument(),
|
||||
opt.getDescription());
|
||||
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
|
||||
os << ", ::llvm::cl::init(" << defaultVal << ")";
|
||||
if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
|
||||
@ -284,9 +285,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
||||
/// Emit the declarations for each of the pass statistics.
|
||||
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
||||
for (const PassStatistic &stat : pass.getStatistics()) {
|
||||
os << llvm::formatv(
|
||||
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
||||
stat.getCppVariableName(), stat.getName(), stat.getDescription());
|
||||
os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
||||
stat.getCppVariableName(), stat.getName(),
|
||||
stat.getDescription());
|
||||
}
|
||||
}
|
||||
|
||||
@ -300,11 +301,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
||||
os << "#ifdef " << enableVarName << "\n";
|
||||
|
||||
if (emitDefaultConstructors) {
|
||||
os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName);
|
||||
os << formatv(friendDefaultConstructorDeclTemplate, passName);
|
||||
|
||||
if (emitDefaultConstructorWithOptions)
|
||||
os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate,
|
||||
passName);
|
||||
os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName);
|
||||
}
|
||||
|
||||
std::string dependentDialectRegistrations;
|
||||
@ -313,24 +313,23 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
||||
llvm::interleave(
|
||||
pass.getDependentDialects(), dialectsOs,
|
||||
[&](StringRef dependentDialect) {
|
||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
||||
dependentDialect);
|
||||
dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
|
||||
},
|
||||
"\n ");
|
||||
}
|
||||
|
||||
os << "namespace impl {\n";
|
||||
os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(),
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
dependentDialectRegistrations);
|
||||
os << formatv(baseClassBegin, passName, pass.getBaseClass(),
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
dependentDialectRegistrations);
|
||||
|
||||
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
|
||||
os.indent(2) << llvm::formatv(
|
||||
"{0}Base({0}Options options) : {0}Base() {{\n", passName);
|
||||
os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n",
|
||||
passName);
|
||||
|
||||
for (const PassOption &opt : pass.getOptions())
|
||||
os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n",
|
||||
opt.getCppVariableName());
|
||||
os.indent(4) << formatv("{0} = std::move(options.{0});\n",
|
||||
opt.getCppVariableName());
|
||||
|
||||
os.indent(2) << "}\n";
|
||||
}
|
||||
@ -344,21 +343,20 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
||||
os << "private:\n";
|
||||
|
||||
if (emitDefaultConstructors) {
|
||||
os << llvm::formatv(friendDefaultConstructorDefTemplate, passName);
|
||||
os << formatv(friendDefaultConstructorDefTemplate, passName);
|
||||
|
||||
if (!pass.getOptions().empty())
|
||||
os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate,
|
||||
passName);
|
||||
os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName);
|
||||
}
|
||||
|
||||
os << "};\n";
|
||||
os << "} // namespace impl\n";
|
||||
|
||||
if (emitDefaultConstructors) {
|
||||
os << llvm::formatv(defaultConstructorDefTemplate, passName);
|
||||
os << formatv(defaultConstructorDefTemplate, passName);
|
||||
|
||||
if (emitDefaultConstructorWithOptions)
|
||||
os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName);
|
||||
os << formatv(defaultConstructorWithOptionsDefTemplate, passName);
|
||||
}
|
||||
|
||||
os << "#undef " << enableVarName << "\n";
|
||||
@ -367,7 +365,7 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
||||
|
||||
static void emitPass(const Pass &pass, raw_ostream &os) {
|
||||
StringRef passName = pass.getDef()->getName();
|
||||
os << llvm::formatv(passHeader, passName);
|
||||
os << formatv(passHeader, passName);
|
||||
|
||||
emitPassDecls(pass, os);
|
||||
emitPassDefs(pass, os);
|
||||
@ -436,21 +434,19 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
|
||||
llvm::interleave(
|
||||
pass.getDependentDialects(), dialectsOs,
|
||||
[&](StringRef dependentDialect) {
|
||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
||||
dependentDialect);
|
||||
dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
|
||||
},
|
||||
"\n ");
|
||||
}
|
||||
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
dependentDialectRegistrations);
|
||||
os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
|
||||
pass.getArgument(), pass.getSummary(),
|
||||
dependentDialectRegistrations);
|
||||
emitPassOptionDecls(pass, os);
|
||||
emitPassStatisticDecls(pass, os);
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
static void emitPasses(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
static void emitPasses(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
std::vector<Pass> passes = getPasses(recordKeeper);
|
||||
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
||||
|
||||
@ -479,7 +475,7 @@ static void emitPasses(const llvm::RecordKeeper &recordKeeper,
|
||||
|
||||
static mlir::GenRegistration
|
||||
genPassDecls("gen-pass-decls", "Generate pass declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
emitPasses(records, os);
|
||||
return false;
|
||||
});
|
||||
|
@ -98,15 +98,15 @@ public:
|
||||
StringRef getMergeInstance() const;
|
||||
|
||||
// Returns the underlying LLVM TableGen Record.
|
||||
const llvm::Record *getDef() const { return def; }
|
||||
const Record *getDef() const { return def; }
|
||||
|
||||
private:
|
||||
// The TableGen definition of this availability.
|
||||
const llvm::Record *def;
|
||||
const Record *def;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
Availability::Availability(const llvm::Record *def) : def(def) {
|
||||
Availability::Availability(const Record *def) : def(def) {
|
||||
assert(def->isSubClassOf("Availability") &&
|
||||
"must be subclass of TableGen 'Availability' class");
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user