[NFC][MLIR][TableGen] Eliminate llvm::
for commonly used types (#110841)
Eliminate `llvm::` namespace qualifier for commonly used types in MLIR TableGen backends to reduce code clutter.
This commit is contained in:
parent
906ffc4b4a
commit
bccd37f69f
@ -22,6 +22,8 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Utility Functions
|
// Utility Functions
|
||||||
@ -30,14 +32,14 @@ using namespace mlir::tblgen;
|
|||||||
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
|
/// Find all the AttrOrTypeDef for the specified dialect. If no dialect
|
||||||
/// specified and can only find one dialect's defs, use that.
|
/// specified and can only find one dialect's defs, use that.
|
||||||
static void collectAllDefs(StringRef selectedDialect,
|
static void collectAllDefs(StringRef selectedDialect,
|
||||||
ArrayRef<const llvm::Record *> records,
|
ArrayRef<const Record *> records,
|
||||||
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
|
SmallVectorImpl<AttrOrTypeDef> &resultDefs) {
|
||||||
// Nothing to do if no defs were found.
|
// Nothing to do if no defs were found.
|
||||||
if (records.empty())
|
if (records.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto defs = llvm::map_range(
|
auto defs = llvm::map_range(
|
||||||
records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); });
|
records, [&](const Record *rec) { return AttrOrTypeDef(rec); });
|
||||||
if (selectedDialect.empty()) {
|
if (selectedDialect.empty()) {
|
||||||
// If a dialect was not specified, ensure that all found defs belong to the
|
// If a dialect was not specified, ensure that all found defs belong to the
|
||||||
// same dialect.
|
// same dialect.
|
||||||
@ -690,13 +692,12 @@ public:
|
|||||||
bool emitDefs(StringRef selectedDialect);
|
bool emitDefs(StringRef selectedDialect);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
DefGenerator(ArrayRef<const llvm::Record *> defs, raw_ostream &os,
|
DefGenerator(ArrayRef<const Record *> defs, raw_ostream &os,
|
||||||
StringRef defType, StringRef valueType, bool isAttrGenerator)
|
StringRef defType, StringRef valueType, bool isAttrGenerator)
|
||||||
: defRecords(defs), os(os), defType(defType), valueType(valueType),
|
: defRecords(defs), os(os), defType(defType), valueType(valueType),
|
||||||
isAttrGenerator(isAttrGenerator) {
|
isAttrGenerator(isAttrGenerator) {
|
||||||
// Sort by occurrence in file.
|
// Sort by occurrence in file.
|
||||||
llvm::sort(defRecords,
|
llvm::sort(defRecords, [](const Record *lhs, const Record *rhs) {
|
||||||
[](const llvm::Record *lhs, const llvm::Record *rhs) {
|
|
||||||
return lhs->getID() < rhs->getID();
|
return lhs->getID() < rhs->getID();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -707,7 +708,7 @@ protected:
|
|||||||
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
|
void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs);
|
||||||
|
|
||||||
/// The set of def records to emit.
|
/// The set of def records to emit.
|
||||||
std::vector<const llvm::Record *> defRecords;
|
std::vector<const Record *> defRecords;
|
||||||
/// The attribute or type class to emit.
|
/// The attribute or type class to emit.
|
||||||
/// The stream to emit to.
|
/// The stream to emit to.
|
||||||
raw_ostream &os;
|
raw_ostream &os;
|
||||||
@ -722,13 +723,13 @@ protected:
|
|||||||
|
|
||||||
/// A specialized generator for AttrDefs.
|
/// A specialized generator for AttrDefs.
|
||||||
struct AttrDefGenerator : public DefGenerator {
|
struct AttrDefGenerator : public DefGenerator {
|
||||||
AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
AttrDefGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||||
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
|
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("AttrDef"), os,
|
||||||
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
|
"Attr", "Attribute", /*isAttrGenerator=*/true) {}
|
||||||
};
|
};
|
||||||
/// A specialized generator for TypeDefs.
|
/// A specialized generator for TypeDefs.
|
||||||
struct TypeDefGenerator : public DefGenerator {
|
struct TypeDefGenerator : public DefGenerator {
|
||||||
TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
TypeDefGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||||
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
|
: DefGenerator(records.getAllDerivedDefinitionsIfDefined("TypeDef"), os,
|
||||||
"Type", "Type", /*isAttrGenerator=*/false) {}
|
"Type", "Type", /*isAttrGenerator=*/false) {}
|
||||||
};
|
};
|
||||||
@ -1030,9 +1031,9 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
|
|||||||
|
|
||||||
/// Find all type constraints for which a C++ function should be generated.
|
/// Find all type constraints for which a C++ function should be generated.
|
||||||
static std::vector<Constraint>
|
static std::vector<Constraint>
|
||||||
getAllTypeConstraints(const llvm::RecordKeeper &records) {
|
getAllTypeConstraints(const RecordKeeper &records) {
|
||||||
std::vector<Constraint> result;
|
std::vector<Constraint> result;
|
||||||
for (const llvm::Record *def :
|
for (const Record *def :
|
||||||
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
|
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
|
||||||
// Ignore constraints defined outside of the top-level file.
|
// Ignore constraints defined outside of the top-level file.
|
||||||
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
|
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
|
||||||
@ -1047,7 +1048,7 @@ getAllTypeConstraints(const llvm::RecordKeeper &records) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
|
static void emitTypeConstraintDecls(const RecordKeeper &records,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
static const char *const typeConstraintDecl = R"(
|
static const char *const typeConstraintDecl = R"(
|
||||||
bool {0}(::mlir::Type type);
|
bool {0}(::mlir::Type type);
|
||||||
@ -1057,7 +1058,7 @@ bool {0}(::mlir::Type type);
|
|||||||
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
|
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
|
static void emitTypeConstraintDefs(const RecordKeeper &records,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
static const char *const typeConstraintDef = R"(
|
static const char *const typeConstraintDef = R"(
|
||||||
bool {0}(::mlir::Type type) {
|
bool {0}(::mlir::Type type) {
|
||||||
@ -1088,13 +1089,13 @@ static llvm::cl::opt<std::string>
|
|||||||
|
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
|
genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
AttrDefGenerator generator(records, os);
|
AttrDefGenerator generator(records, os);
|
||||||
return generator.emitDefs(attrDialect);
|
return generator.emitDefs(attrDialect);
|
||||||
});
|
});
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
|
genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
AttrDefGenerator generator(records, os);
|
AttrDefGenerator generator(records, os);
|
||||||
return generator.emitDecls(attrDialect);
|
return generator.emitDecls(attrDialect);
|
||||||
});
|
});
|
||||||
@ -1110,13 +1111,13 @@ static llvm::cl::opt<std::string>
|
|||||||
|
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
|
genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
TypeDefGenerator generator(records, os);
|
TypeDefGenerator generator(records, os);
|
||||||
return generator.emitDefs(typeDialect);
|
return generator.emitDefs(typeDialect);
|
||||||
});
|
});
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
|
genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
TypeDefGenerator generator(records, os);
|
TypeDefGenerator generator(records, os);
|
||||||
return generator.emitDecls(typeDialect);
|
return generator.emitDecls(typeDialect);
|
||||||
});
|
});
|
||||||
@ -1124,14 +1125,14 @@ static mlir::GenRegistration
|
|||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genTypeConstrDefs("gen-type-constraint-defs",
|
genTypeConstrDefs("gen-type-constraint-defs",
|
||||||
"Generate type constraint definitions",
|
"Generate type constraint definitions",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
emitTypeConstraintDefs(records, os);
|
emitTypeConstraintDefs(records, os);
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genTypeConstrDecls("gen-type-constraint-decls",
|
genTypeConstrDecls("gen-type-constraint-decls",
|
||||||
"Generate type constraint declarations",
|
"Generate type constraint declarations",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
emitTypeConstraintDecls(records, os);
|
emitTypeConstraintDecls(records, os);
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
@ -18,11 +18,10 @@
|
|||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
|
|
||||||
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
|
static cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
|
||||||
static llvm::cl::opt<std::string>
|
static cl::opt<std::string>
|
||||||
selectedBcDialect("bytecode-dialect",
|
selectedBcDialect("bytecode-dialect", cl::desc("The dialect to gen for"),
|
||||||
llvm::cl::desc("The dialect to gen for"),
|
cl::cat(dialectGenCat), cl::CommaSeparated);
|
||||||
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -306,7 +305,7 @@ void Generator::emitPrint(StringRef kind, StringRef type,
|
|||||||
auto funScope = os.scope("{\n", "}\n\n");
|
auto funScope = os.scope("{\n", "}\n\n");
|
||||||
|
|
||||||
// Check that predicates specified if multiple bytecode instances.
|
// Check that predicates specified if multiple bytecode instances.
|
||||||
for (const llvm::Record *rec : make_second_range(vec)) {
|
for (const Record *rec : make_second_range(vec)) {
|
||||||
StringRef pred = rec->getValueAsString("printerPredicate");
|
StringRef pred = rec->getValueAsString("printerPredicate");
|
||||||
if (vec.size() > 1 && pred.empty()) {
|
if (vec.size() > 1 && pred.empty()) {
|
||||||
for (auto [index, rec] : vec) {
|
for (auto [index, rec] : vec) {
|
||||||
|
@ -30,6 +30,8 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
|
||||||
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
|
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-dialect-*");
|
||||||
llvm::cl::opt<std::string>
|
llvm::cl::opt<std::string>
|
||||||
@ -39,8 +41,8 @@ llvm::cl::opt<std::string>
|
|||||||
/// Utility iterator used for filtering records for a specific dialect.
|
/// Utility iterator used for filtering records for a specific dialect.
|
||||||
namespace {
|
namespace {
|
||||||
using DialectFilterIterator =
|
using DialectFilterIterator =
|
||||||
llvm::filter_iterator<ArrayRef<llvm::Record *>::iterator,
|
llvm::filter_iterator<ArrayRef<Record *>::iterator,
|
||||||
std::function<bool(const llvm::Record *)>>;
|
std::function<bool(const Record *)>>;
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static void populateDiscardableAttributes(
|
static void populateDiscardableAttributes(
|
||||||
@ -62,8 +64,8 @@ static void populateDiscardableAttributes(
|
|||||||
/// the given dialect.
|
/// the given dialect.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static iterator_range<DialectFilterIterator>
|
static iterator_range<DialectFilterIterator>
|
||||||
filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
|
filterForDialect(ArrayRef<Record *> records, Dialect &dialect) {
|
||||||
auto filterFn = [&](const llvm::Record *record) {
|
auto filterFn = [&](const Record *record) {
|
||||||
return T(record).getDialect() == dialect;
|
return T(record).getDialect() == dialect;
|
||||||
};
|
};
|
||||||
return {DialectFilterIterator(records.begin(), records.end(), filterFn),
|
return {DialectFilterIterator(records.begin(), records.end(), filterFn),
|
||||||
@ -295,7 +297,7 @@ static void emitDialectDecl(Dialect &dialect, raw_ostream &os) {
|
|||||||
<< "::" << dialect.getCppClassName() << ")\n";
|
<< "::" << dialect.getCppClassName() << ")\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
|
static bool emitDialectDecls(const RecordKeeper &recordKeeper,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
emitSourceFileHeader("Dialect Declarations", os, recordKeeper);
|
emitSourceFileHeader("Dialect Declarations", os, recordKeeper);
|
||||||
|
|
||||||
@ -340,8 +342,7 @@ static const char *const dialectDestructorStr = R"(
|
|||||||
|
|
||||||
)";
|
)";
|
||||||
|
|
||||||
static void emitDialectDef(Dialect &dialect,
|
static void emitDialectDef(Dialect &dialect, const RecordKeeper &recordKeeper,
|
||||||
const llvm::RecordKeeper &recordKeeper,
|
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
std::string cppClassName = dialect.getCppClassName();
|
std::string cppClassName = dialect.getCppClassName();
|
||||||
|
|
||||||
@ -389,8 +390,7 @@ static void emitDialectDef(Dialect &dialect,
|
|||||||
os << llvm::formatv(dialectDestructorStr, cppClassName);
|
os << llvm::formatv(dialectDestructorStr, cppClassName);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
|
static bool emitDialectDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
emitSourceFileHeader("Dialect Definitions", os, recordKeeper);
|
emitSourceFileHeader("Dialect Definitions", os, recordKeeper);
|
||||||
|
|
||||||
auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
|
auto dialectDefs = recordKeeper.getAllDerivedDefinitions("Dialect");
|
||||||
@ -411,12 +411,12 @@ static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,
|
|||||||
|
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
|
genDialectDecls("gen-dialect-decls", "Generate dialect declarations",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
return emitDialectDecls(records, os);
|
return emitDialectDecls(records, os);
|
||||||
});
|
});
|
||||||
|
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
|
genDialectDefs("gen-dialect-defs", "Generate dialect definitions",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
return emitDialectDefs(records, os);
|
return emitDialectDefs(records, os);
|
||||||
});
|
});
|
||||||
|
@ -21,6 +21,9 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::formatv;
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
|
||||||
/// File header and includes.
|
/// File header and includes.
|
||||||
constexpr const char *fileHeader = R"Py(
|
constexpr const char *fileHeader = R"Py(
|
||||||
@ -42,15 +45,15 @@ static std::string makePythonEnumCaseName(StringRef name) {
|
|||||||
|
|
||||||
/// Emits the Python class for the given enum.
|
/// Emits the Python class for the given enum.
|
||||||
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
|
static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
|
||||||
os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
|
os << formatv("class {0}({1}):\n", enumAttr.getEnumClassName(),
|
||||||
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
|
enumAttr.isBitEnum() ? "IntFlag" : "IntEnum");
|
||||||
if (!enumAttr.getSummary().empty())
|
if (!enumAttr.getSummary().empty())
|
||||||
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
|
os << formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary());
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
||||||
os << llvm::formatv(
|
os << formatv(" {0} = {1}\n",
|
||||||
" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()),
|
makePythonEnumCaseName(enumCase.getSymbol()),
|
||||||
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
|
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
|
||||||
: "auto()");
|
: "auto()");
|
||||||
}
|
}
|
||||||
@ -58,27 +61,25 @@ static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) {
|
|||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
if (enumAttr.isBitEnum()) {
|
if (enumAttr.isBitEnum()) {
|
||||||
os << llvm::formatv(" def __iter__(self):\n"
|
os << formatv(" def __iter__(self):\n"
|
||||||
" return iter([case for case in type(self) if "
|
" return iter([case for case in type(self) if "
|
||||||
"(self & case) is case])\n");
|
"(self & case) is case])\n");
|
||||||
os << llvm::formatv(" def __len__(self):\n"
|
os << formatv(" def __len__(self):\n"
|
||||||
" return bin(self).count(\"1\")\n");
|
" return bin(self).count(\"1\")\n");
|
||||||
os << "\n";
|
os << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
os << llvm::formatv(" def __str__(self):\n");
|
os << formatv(" def __str__(self):\n");
|
||||||
if (enumAttr.isBitEnum())
|
if (enumAttr.isBitEnum())
|
||||||
os << llvm::formatv(" if len(self) > 1:\n"
|
os << formatv(" if len(self) > 1:\n"
|
||||||
" return \"{0}\".join(map(str, self))\n",
|
" return \"{0}\".join(map(str, self))\n",
|
||||||
enumAttr.getDef().getValueAsString("separator"));
|
enumAttr.getDef().getValueAsString("separator"));
|
||||||
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) {
|
||||||
os << llvm::formatv(" if self is {0}.{1}:\n",
|
os << formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(),
|
||||||
enumAttr.getEnumClassName(),
|
|
||||||
makePythonEnumCaseName(enumCase.getSymbol()));
|
makePythonEnumCaseName(enumCase.getSymbol()));
|
||||||
os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr());
|
os << formatv(" return \"{0}\"\n", enumCase.getStr());
|
||||||
}
|
}
|
||||||
os << llvm::formatv(
|
os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
|
||||||
" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
|
|
||||||
enumAttr.getEnumClassName());
|
enumAttr.getEnumClassName());
|
||||||
os << "\n";
|
os << "\n";
|
||||||
}
|
}
|
||||||
@ -105,12 +106,10 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n",
|
os << formatv("@register_attribute_builder(\"{0}\")\n",
|
||||||
enumAttr.getAttrDefName());
|
enumAttr.getAttrDefName());
|
||||||
os << llvm::formatv("def _{0}(x, context):\n",
|
os << formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower());
|
||||||
enumAttr.getAttrDefName().lower());
|
os << formatv(" return "
|
||||||
os << llvm::formatv(
|
|
||||||
" return "
|
|
||||||
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
|
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
|
||||||
"context=context), int(x))\n\n",
|
"context=context), int(x))\n\n",
|
||||||
bitwidth);
|
bitwidth);
|
||||||
@ -123,9 +122,9 @@ static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
|
|||||||
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
|
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
|
||||||
StringRef formatString,
|
StringRef formatString,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
|
os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
|
||||||
os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower());
|
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
|
||||||
os << llvm::formatv(" return "
|
os << formatv(" return "
|
||||||
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
|
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
|
||||||
formatString);
|
formatString);
|
||||||
return false;
|
return false;
|
||||||
@ -133,16 +132,15 @@ static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
|
|||||||
|
|
||||||
/// Emits Python bindings for all enums in the record keeper. Returns
|
/// Emits Python bindings for all enums in the record keeper. Returns
|
||||||
/// `false` on success, `true` on failure.
|
/// `false` on success, `true` on failure.
|
||||||
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
|
static bool emitPythonEnums(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
os << fileHeader;
|
os << fileHeader;
|
||||||
for (const llvm::Record *it :
|
for (const Record *it :
|
||||||
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
|
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) {
|
||||||
EnumAttr enumAttr(*it);
|
EnumAttr enumAttr(*it);
|
||||||
emitEnumClass(enumAttr, os);
|
emitEnumClass(enumAttr, os);
|
||||||
emitAttributeBuilder(enumAttr, os);
|
emitAttributeBuilder(enumAttr, os);
|
||||||
}
|
}
|
||||||
for (const llvm::Record *it :
|
for (const Record *it :
|
||||||
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
|
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
|
||||||
AttrOrTypeDef attr(&*it);
|
AttrOrTypeDef attr(&*it);
|
||||||
if (!attr.getMnemonic()) {
|
if (!attr.getMnemonic()) {
|
||||||
@ -156,11 +154,11 @@ static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
|
|||||||
if (assemblyFormat == "`<` $value `>`") {
|
if (assemblyFormat == "`<` $value `>`") {
|
||||||
emitDialectEnumAttributeBuilder(
|
emitDialectEnumAttributeBuilder(
|
||||||
attr.getName(),
|
attr.getName(),
|
||||||
llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
|
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
|
||||||
} else if (assemblyFormat == "$value") {
|
} else if (assemblyFormat == "$value") {
|
||||||
emitDialectEnumAttributeBuilder(
|
emitDialectEnumAttributeBuilder(
|
||||||
attr.getName(),
|
attr.getName(),
|
||||||
llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
|
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
|
||||||
} else {
|
} else {
|
||||||
llvm::errs()
|
llvm::errs()
|
||||||
<< "unsupported assembly format for python enum bindings generation";
|
<< "unsupported assembly format for python enum bindings generation";
|
||||||
|
@ -26,10 +26,9 @@
|
|||||||
using llvm::formatv;
|
using llvm::formatv;
|
||||||
using llvm::isDigit;
|
using llvm::isDigit;
|
||||||
using llvm::PrintFatalError;
|
using llvm::PrintFatalError;
|
||||||
using llvm::raw_ostream;
|
|
||||||
using llvm::Record;
|
using llvm::Record;
|
||||||
using llvm::RecordKeeper;
|
using llvm::RecordKeeper;
|
||||||
using llvm::StringRef;
|
using namespace mlir;
|
||||||
using mlir::tblgen::Attribute;
|
using mlir::tblgen::Attribute;
|
||||||
using mlir::tblgen::EnumAttr;
|
using mlir::tblgen::EnumAttr;
|
||||||
using mlir::tblgen::EnumAttrCase;
|
using mlir::tblgen::EnumAttrCase;
|
||||||
@ -139,7 +138,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
|
|||||||
// is not a power of two (i.e. not a single bit case) and not a known case.
|
// is not a power of two (i.e. not a single bit case) and not a known case.
|
||||||
} else if (enumAttr.isBitEnum()) {
|
} else if (enumAttr.isBitEnum()) {
|
||||||
// Process the known multi-bit cases that use valid keywords.
|
// Process the known multi-bit cases that use valid keywords.
|
||||||
llvm::SmallVector<EnumAttrCase *> validMultiBitCases;
|
SmallVector<EnumAttrCase *> validMultiBitCases;
|
||||||
for (auto [index, caseVal] : llvm::enumerate(cases)) {
|
for (auto [index, caseVal] : llvm::enumerate(cases)) {
|
||||||
uint64_t value = caseVal.getValue();
|
uint64_t value = caseVal.getValue();
|
||||||
if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
|
if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
|
||||||
@ -476,7 +475,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
|
|||||||
EnumAttr enumAttr(enumDef);
|
EnumAttr enumAttr(enumDef);
|
||||||
StringRef enumName = enumAttr.getEnumClassName();
|
StringRef enumName = enumAttr.getEnumClassName();
|
||||||
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
|
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
|
||||||
const llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
|
const Record *baseAttrDef = enumAttr.getBaseAttrClass();
|
||||||
Attribute baseAttr(baseAttrDef);
|
Attribute baseAttr(baseAttrDef);
|
||||||
|
|
||||||
// Emit classof method
|
// Emit classof method
|
||||||
@ -565,7 +564,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
|
|||||||
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
|
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
|
||||||
auto enumerants = enumAttr.getAllCases();
|
auto enumerants = enumAttr.getAllCases();
|
||||||
|
|
||||||
llvm::SmallVector<StringRef, 2> namespaces;
|
SmallVector<StringRef, 2> namespaces;
|
||||||
llvm::SplitString(cppNamespace, namespaces, "::");
|
llvm::SplitString(cppNamespace, namespaces, "::");
|
||||||
|
|
||||||
for (auto ns : namespaces)
|
for (auto ns : namespaces)
|
||||||
@ -656,7 +655,7 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
|
|||||||
EnumAttr enumAttr(enumDef);
|
EnumAttr enumAttr(enumDef);
|
||||||
StringRef cppNamespace = enumAttr.getCppNamespace();
|
StringRef cppNamespace = enumAttr.getCppNamespace();
|
||||||
|
|
||||||
llvm::SmallVector<StringRef, 2> namespaces;
|
SmallVector<StringRef, 2> namespaces;
|
||||||
llvm::SplitString(cppNamespace, namespaces, "::");
|
llvm::SplitString(cppNamespace, namespaces, "::");
|
||||||
|
|
||||||
for (auto ns : namespaces)
|
for (auto ns : namespaces)
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::SourceMgr;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// FormatToken
|
// FormatToken
|
||||||
@ -26,14 +27,14 @@ SMLoc FormatToken::getLoc() const {
|
|||||||
// FormatLexer
|
// FormatLexer
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc)
|
FormatLexer::FormatLexer(SourceMgr &mgr, SMLoc loc)
|
||||||
: mgr(mgr), loc(loc),
|
: mgr(mgr), loc(loc),
|
||||||
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
|
curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
|
||||||
curPtr(curBuffer.begin()) {}
|
curPtr(curBuffer.begin()) {}
|
||||||
|
|
||||||
FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
|
FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
|
||||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
|
||||||
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
|
llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
|
||||||
"in custom assembly format for this operation");
|
"in custom assembly format for this operation");
|
||||||
return formToken(FormatToken::error, loc.getPointer());
|
return formToken(FormatToken::error, loc.getPointer());
|
||||||
}
|
}
|
||||||
@ -44,10 +45,10 @@ FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
|
|||||||
|
|
||||||
FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
|
FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
|
||||||
const Twine ¬e) {
|
const Twine ¬e) {
|
||||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
|
mgr.PrintMessage(loc, SourceMgr::DK_Error, msg);
|
||||||
llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
|
llvm::SrcMgr.PrintMessage(this->loc, SourceMgr::DK_Note,
|
||||||
"in custom assembly format for this operation");
|
"in custom assembly format for this operation");
|
||||||
mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
|
mgr.PrintMessage(loc, SourceMgr::DK_Note, note);
|
||||||
return formToken(FormatToken::error, loc.getPointer());
|
return formToken(FormatToken::error, loc.getPointer());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -411,8 +411,7 @@ public:
|
|||||||
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
|
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
|
||||||
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
|
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
|
||||||
// (Enum) to the corresponding LLVM API enumerant
|
// (Enum) to the corresponding LLVM API enumerant
|
||||||
static void emitOneEnumToConversion(const llvm::Record *record,
|
static void emitOneEnumToConversion(const Record *record, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
LLVMEnumAttr enumAttr(record);
|
LLVMEnumAttr enumAttr(record);
|
||||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||||
@ -441,8 +440,7 @@ static void emitOneEnumToConversion(const llvm::Record *record,
|
|||||||
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
|
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
|
||||||
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
|
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
|
||||||
// (Enum) to the corresponding LLVM API C-style enumerant
|
// (Enum) to the corresponding LLVM API C-style enumerant
|
||||||
static void emitOneCEnumToConversion(const llvm::Record *record,
|
static void emitOneCEnumToConversion(const Record *record, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
LLVMCEnumAttr enumAttr(record);
|
LLVMCEnumAttr enumAttr(record);
|
||||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||||
@ -472,8 +470,7 @@ static void emitOneCEnumToConversion(const llvm::Record *record,
|
|||||||
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
|
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
|
||||||
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
|
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
|
||||||
// LLVM dialect enum attribute (Enum).
|
// LLVM dialect enum attribute (Enum).
|
||||||
static void emitOneEnumFromConversion(const llvm::Record *record,
|
static void emitOneEnumFromConversion(const Record *record, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
LLVMEnumAttr enumAttr(record);
|
LLVMEnumAttr enumAttr(record);
|
||||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||||
@ -508,8 +505,7 @@ static void emitOneEnumFromConversion(const llvm::Record *record,
|
|||||||
// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
|
// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
|
||||||
// containing switch-based logic to convert from the LLVM API C-style enumerant
|
// containing switch-based logic to convert from the LLVM API C-style enumerant
|
||||||
// to MLIR LLVM dialect enum attribute (Enum).
|
// to MLIR LLVM dialect enum attribute (Enum).
|
||||||
static void emitOneCEnumFromConversion(const llvm::Record *record,
|
static void emitOneCEnumFromConversion(const Record *record, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
LLVMCEnumAttr enumAttr(record);
|
LLVMCEnumAttr enumAttr(record);
|
||||||
StringRef llvmClass = enumAttr.getLLVMClassName();
|
StringRef llvmClass = enumAttr.getLLVMClassName();
|
||||||
StringRef cppClassName = enumAttr.getEnumClassName();
|
StringRef cppClassName = enumAttr.getEnumClassName();
|
||||||
|
@ -24,6 +24,11 @@
|
|||||||
#include "llvm/TableGen/Record.h"
|
#include "llvm/TableGen/Record.h"
|
||||||
#include "llvm/TableGen/TableGenBackend.h"
|
#include "llvm/TableGen/TableGenBackend.h"
|
||||||
|
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
using llvm::Regex;
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options");
|
static llvm::cl::OptionCategory intrinsicGenCat("Intrinsics Generator Options");
|
||||||
|
|
||||||
static llvm::cl::opt<std::string>
|
static llvm::cl::opt<std::string>
|
||||||
@ -54,14 +59,14 @@ static llvm::cl::opt<std::string> aliasAnalysisRegexp(
|
|||||||
using IndicesTy = llvm::SmallBitVector;
|
using IndicesTy = llvm::SmallBitVector;
|
||||||
|
|
||||||
/// Return a CodeGen value type entry from a type record.
|
/// Return a CodeGen value type entry from a type record.
|
||||||
static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) {
|
static llvm::MVT::SimpleValueType getValueType(const Record *rec) {
|
||||||
return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
|
return (llvm::MVT::SimpleValueType)rec->getValueAsDef("VT")->getValueAsInt(
|
||||||
"Value");
|
"Value");
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the indices of the definitions in a list of definitions that
|
/// Return the indices of the definitions in a list of definitions that
|
||||||
/// represent overloadable types
|
/// represent overloadable types
|
||||||
static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record,
|
static IndicesTy getOverloadableTypeIdxs(const Record &record,
|
||||||
const char *listName) {
|
const char *listName) {
|
||||||
auto results = record.getValueAsListOfDefs(listName);
|
auto results = record.getValueAsListOfDefs(listName);
|
||||||
IndicesTy overloadedOps(results.size());
|
IndicesTy overloadedOps(results.size());
|
||||||
@ -87,13 +92,13 @@ namespace {
|
|||||||
/// the fields of the record.
|
/// the fields of the record.
|
||||||
class LLVMIntrinsic {
|
class LLVMIntrinsic {
|
||||||
public:
|
public:
|
||||||
LLVMIntrinsic(const llvm::Record &record) : record(record) {}
|
LLVMIntrinsic(const Record &record) : record(record) {}
|
||||||
|
|
||||||
/// Get the name of the operation to be used in MLIR. Uses the appropriate
|
/// Get the name of the operation to be used in MLIR. Uses the appropriate
|
||||||
/// field if not empty, constructs a name by replacing underscores with dots
|
/// field if not empty, constructs a name by replacing underscores with dots
|
||||||
/// in the record name otherwise.
|
/// in the record name otherwise.
|
||||||
std::string getOperationName() const {
|
std::string getOperationName() const {
|
||||||
llvm::StringRef name = record.getValueAsString(fieldName);
|
StringRef name = record.getValueAsString(fieldName);
|
||||||
if (!name.empty())
|
if (!name.empty())
|
||||||
return name.str();
|
return name.str();
|
||||||
|
|
||||||
@ -101,8 +106,8 @@ public:
|
|||||||
assert(name.starts_with("int_") &&
|
assert(name.starts_with("int_") &&
|
||||||
"LLVM intrinsic names are expected to start with 'int_'");
|
"LLVM intrinsic names are expected to start with 'int_'");
|
||||||
name = name.drop_front(4);
|
name = name.drop_front(4);
|
||||||
llvm::SmallVector<llvm::StringRef, 8> chunks;
|
SmallVector<StringRef, 8> chunks;
|
||||||
llvm::StringRef targetPrefix = record.getValueAsString("TargetPrefix");
|
StringRef targetPrefix = record.getValueAsString("TargetPrefix");
|
||||||
name.split(chunks, '_');
|
name.split(chunks, '_');
|
||||||
auto *chunksBegin = chunks.begin();
|
auto *chunksBegin = chunks.begin();
|
||||||
// Remove the target prefix from target specific intrinsics.
|
// Remove the target prefix from target specific intrinsics.
|
||||||
@ -119,8 +124,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the name of the record without the "intrinsic" prefix.
|
/// Get the name of the record without the "intrinsic" prefix.
|
||||||
llvm::StringRef getProperRecordName() const {
|
StringRef getProperRecordName() const {
|
||||||
llvm::StringRef name = record.getName();
|
StringRef name = record.getName();
|
||||||
assert(name.starts_with("int_") &&
|
assert(name.starts_with("int_") &&
|
||||||
"LLVM intrinsic names are expected to start with 'int_'");
|
"LLVM intrinsic names are expected to start with 'int_'");
|
||||||
return name.drop_front(4);
|
return name.drop_front(4);
|
||||||
@ -129,10 +134,9 @@ public:
|
|||||||
/// Get the number of operands.
|
/// Get the number of operands.
|
||||||
unsigned getNumOperands() const {
|
unsigned getNumOperands() const {
|
||||||
auto operands = record.getValueAsListOfDefs(fieldOperands);
|
auto operands = record.getValueAsListOfDefs(fieldOperands);
|
||||||
assert(llvm::all_of(operands,
|
assert(llvm::all_of(
|
||||||
[](const llvm::Record *r) {
|
operands,
|
||||||
return r->isSubClassOf("LLVMType");
|
[](const Record *r) { return r->isSubClassOf("LLVMType"); }) &&
|
||||||
}) &&
|
|
||||||
"expected operands to be of LLVM type");
|
"expected operands to be of LLVM type");
|
||||||
return operands.size();
|
return operands.size();
|
||||||
}
|
}
|
||||||
@ -142,7 +146,7 @@ public:
|
|||||||
/// structure type.
|
/// structure type.
|
||||||
unsigned getNumResults() const {
|
unsigned getNumResults() const {
|
||||||
auto results = record.getValueAsListOfDefs(fieldResults);
|
auto results = record.getValueAsListOfDefs(fieldResults);
|
||||||
for (const llvm::Record *r : results) {
|
for (const Record *r : results) {
|
||||||
(void)r;
|
(void)r;
|
||||||
assert(r->isSubClassOf("LLVMType") &&
|
assert(r->isSubClassOf("LLVMType") &&
|
||||||
"expected operands to be of LLVM type");
|
"expected operands to be of LLVM type");
|
||||||
@ -155,7 +159,7 @@ public:
|
|||||||
bool hasSideEffects() const {
|
bool hasSideEffects() const {
|
||||||
return llvm::none_of(
|
return llvm::none_of(
|
||||||
record.getValueAsListOfDefs(fieldTraits),
|
record.getValueAsListOfDefs(fieldTraits),
|
||||||
[](const llvm::Record *r) { return r->getName() == "IntrNoMem"; });
|
[](const Record *r) { return r->getName() == "IntrNoMem"; });
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return true if the intrinsic is commutative, i.e. has the respective
|
/// Return true if the intrinsic is commutative, i.e. has the respective
|
||||||
@ -163,7 +167,7 @@ public:
|
|||||||
bool isCommutative() const {
|
bool isCommutative() const {
|
||||||
return llvm::any_of(
|
return llvm::any_of(
|
||||||
record.getValueAsListOfDefs(fieldTraits),
|
record.getValueAsListOfDefs(fieldTraits),
|
||||||
[](const llvm::Record *r) { return r->getName() == "Commutative"; });
|
[](const Record *r) { return r->getName() == "Commutative"; });
|
||||||
}
|
}
|
||||||
|
|
||||||
IndicesTy getOverloadableOperandsIdxs() const {
|
IndicesTy getOverloadableOperandsIdxs() const {
|
||||||
@ -181,7 +185,7 @@ private:
|
|||||||
const char *fieldResults = "RetTypes";
|
const char *fieldResults = "RetTypes";
|
||||||
const char *fieldTraits = "IntrProperties";
|
const char *fieldTraits = "IntrProperties";
|
||||||
|
|
||||||
const llvm::Record &record;
|
const Record &record;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -195,27 +199,26 @@ void printBracketedRange(const Range &range, llvm::raw_ostream &os) {
|
|||||||
|
|
||||||
/// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic.
|
/// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic.
|
||||||
/// Returns true on error, false on success.
|
/// Returns true on error, false on success.
|
||||||
static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
|
static bool emitIntrinsic(const Record &record, llvm::raw_ostream &os) {
|
||||||
LLVMIntrinsic intr(record);
|
LLVMIntrinsic intr(record);
|
||||||
|
|
||||||
llvm::Regex accessGroupMatcher(accessGroupRegexp);
|
Regex accessGroupMatcher(accessGroupRegexp);
|
||||||
bool requiresAccessGroup =
|
bool requiresAccessGroup =
|
||||||
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
|
!accessGroupRegexp.empty() && accessGroupMatcher.match(record.getName());
|
||||||
|
|
||||||
llvm::Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
|
Regex aliasAnalysisMatcher(aliasAnalysisRegexp);
|
||||||
bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() &&
|
bool requiresAliasAnalysis = !aliasAnalysisRegexp.empty() &&
|
||||||
aliasAnalysisMatcher.match(record.getName());
|
aliasAnalysisMatcher.match(record.getName());
|
||||||
|
|
||||||
// Prepare strings for traits, if any.
|
// Prepare strings for traits, if any.
|
||||||
llvm::SmallVector<llvm::StringRef, 2> traits;
|
SmallVector<StringRef, 2> traits;
|
||||||
if (intr.isCommutative())
|
if (intr.isCommutative())
|
||||||
traits.push_back("Commutative");
|
traits.push_back("Commutative");
|
||||||
if (!intr.hasSideEffects())
|
if (!intr.hasSideEffects())
|
||||||
traits.push_back("NoMemoryEffect");
|
traits.push_back("NoMemoryEffect");
|
||||||
|
|
||||||
// Prepare strings for operands.
|
// Prepare strings for operands.
|
||||||
llvm::SmallVector<llvm::StringRef, 8> operands(intr.getNumOperands(),
|
SmallVector<StringRef, 8> operands(intr.getNumOperands(), "LLVM_Type");
|
||||||
"LLVM_Type");
|
|
||||||
if (requiresAccessGroup)
|
if (requiresAccessGroup)
|
||||||
operands.push_back(
|
operands.push_back(
|
||||||
"OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups");
|
"OptionalAttr<LLVM_AccessGroupArrayAttr>:$access_groups");
|
||||||
@ -247,14 +250,13 @@ static bool emitIntrinsic(const llvm::Record &record, llvm::raw_ostream &os) {
|
|||||||
/// Traverses the list of TableGen definitions derived from the "Intrinsic"
|
/// Traverses the list of TableGen definitions derived from the "Intrinsic"
|
||||||
/// class and generates MLIR ODS definitions for those intrinsics that have
|
/// class and generates MLIR ODS definitions for those intrinsics that have
|
||||||
/// the name matching the filter.
|
/// the name matching the filter.
|
||||||
static bool emitIntrinsics(const llvm::RecordKeeper &records,
|
static bool emitIntrinsics(const RecordKeeper &records, llvm::raw_ostream &os) {
|
||||||
llvm::raw_ostream &os) {
|
|
||||||
llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records);
|
llvm::emitSourceFileHeader("Operations for LLVM intrinsics", os, records);
|
||||||
os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n";
|
os << "include \"mlir/Dialect/LLVMIR/LLVMOpBase.td\"\n";
|
||||||
os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n";
|
os << "include \"mlir/Interfaces/SideEffectInterfaces.td\"\n\n";
|
||||||
|
|
||||||
auto defs = records.getAllDerivedDefinitions("Intrinsic");
|
auto defs = records.getAllDerivedDefinitions("Intrinsic");
|
||||||
for (const llvm::Record *r : defs) {
|
for (const Record *r : defs) {
|
||||||
if (!nameFilter.empty() && !r->getName().contains(nameFilter))
|
if (!nameFilter.empty() && !r->getName().contains(nameFilter))
|
||||||
continue;
|
continue;
|
||||||
if (emitIntrinsic(*r, os))
|
if (emitIntrinsic(*r, os))
|
||||||
|
@ -34,30 +34,30 @@
|
|||||||
#include <set>
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Commandline Options
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
static llvm::cl::OptionCategory
|
|
||||||
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
|
|
||||||
llvm::cl::opt<std::string>
|
|
||||||
stripPrefix("strip-prefix",
|
|
||||||
llvm::cl::desc("Strip prefix of the fully qualified names"),
|
|
||||||
llvm::cl::init("::mlir::"), llvm::cl::cat(docCat));
|
|
||||||
llvm::cl::opt<bool> allowHugoSpecificFeatures(
|
|
||||||
"allow-hugo-specific-features",
|
|
||||||
llvm::cl::desc("Allows using features specific to Hugo"),
|
|
||||||
llvm::cl::init(false), llvm::cl::cat(docCat));
|
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
using mlir::tblgen::Operator;
|
using mlir::tblgen::Operator;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Commandline Options
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
static cl::OptionCategory
|
||||||
|
docCat("Options for -gen-(attrdef|typedef|enum|op|dialect)-doc");
|
||||||
|
cl::opt<std::string>
|
||||||
|
stripPrefix("strip-prefix",
|
||||||
|
cl::desc("Strip prefix of the fully qualified names"),
|
||||||
|
cl::init("::mlir::"), cl::cat(docCat));
|
||||||
|
cl::opt<bool> allowHugoSpecificFeatures(
|
||||||
|
"allow-hugo-specific-features",
|
||||||
|
cl::desc("Allows using features specific to Hugo"), cl::init(false),
|
||||||
|
cl::cat(docCat));
|
||||||
|
|
||||||
void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) {
|
void mlir::tblgen::emitSummary(StringRef summary, raw_ostream &os) {
|
||||||
if (!summary.empty()) {
|
if (!summary.empty()) {
|
||||||
llvm::StringRef trimmed = summary.trim();
|
StringRef trimmed = summary.trim();
|
||||||
char first = std::toupper(trimmed.front());
|
char first = std::toupper(trimmed.front());
|
||||||
llvm::StringRef rest = trimmed.drop_front();
|
StringRef rest = trimmed.drop_front();
|
||||||
os << "\n_" << first << rest << "_\n\n";
|
os << "\n_" << first << rest << "_\n\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -152,10 +152,10 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
|
|||||||
effectName.consume_front("::");
|
effectName.consume_front("::");
|
||||||
effectName.consume_front("mlir::");
|
effectName.consume_front("mlir::");
|
||||||
std::string effectStr;
|
std::string effectStr;
|
||||||
llvm::raw_string_ostream os(effectStr);
|
raw_string_ostream os(effectStr);
|
||||||
os << effectName << "{";
|
os << effectName << "{";
|
||||||
auto list = trait.getDef().getValueAsListOfDefs("effects");
|
auto list = trait.getDef().getValueAsListOfDefs("effects");
|
||||||
llvm::interleaveComma(list, os, [&](const Record *rec) {
|
interleaveComma(list, os, [&](const Record *rec) {
|
||||||
StringRef effect = rec->getValueAsString("effect");
|
StringRef effect = rec->getValueAsString("effect");
|
||||||
effect.consume_front("::");
|
effect.consume_front("::");
|
||||||
effect.consume_front("mlir::");
|
effect.consume_front("mlir::");
|
||||||
@ -163,7 +163,7 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
|
|||||||
});
|
});
|
||||||
os << "}";
|
os << "}";
|
||||||
effects.insert(backticks(effectStr));
|
effects.insert(backticks(effectStr));
|
||||||
name.append(llvm::formatv(" ({0})", traitName).str());
|
name.append(formatv(" ({0})", traitName).str());
|
||||||
}
|
}
|
||||||
interfaces.insert(backticks(name));
|
interfaces.insert(backticks(name));
|
||||||
continue;
|
continue;
|
||||||
@ -172,15 +172,15 @@ static void emitOpTraitsDoc(const Operator &op, raw_ostream &os) {
|
|||||||
traits.insert(backticks(name));
|
traits.insert(backticks(name));
|
||||||
}
|
}
|
||||||
if (!traits.empty()) {
|
if (!traits.empty()) {
|
||||||
llvm::interleaveComma(traits, os << "\nTraits: ");
|
interleaveComma(traits, os << "\nTraits: ");
|
||||||
os << "\n";
|
os << "\n";
|
||||||
}
|
}
|
||||||
if (!interfaces.empty()) {
|
if (!interfaces.empty()) {
|
||||||
llvm::interleaveComma(interfaces, os << "\nInterfaces: ");
|
interleaveComma(interfaces, os << "\nInterfaces: ");
|
||||||
os << "\n";
|
os << "\n";
|
||||||
}
|
}
|
||||||
if (!effects.empty()) {
|
if (!effects.empty()) {
|
||||||
llvm::interleaveComma(effects, os << "\nEffects: ");
|
interleaveComma(effects, os << "\nEffects: ");
|
||||||
os << "\n";
|
os << "\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -196,7 +196,7 @@ static void emitOpDoc(const Operator &op, raw_ostream &os) {
|
|||||||
std::string classNameStr = op.getQualCppClassName();
|
std::string classNameStr = op.getQualCppClassName();
|
||||||
StringRef className = classNameStr;
|
StringRef className = classNameStr;
|
||||||
(void)className.consume_front(stripPrefix);
|
(void)className.consume_front(stripPrefix);
|
||||||
os << llvm::formatv("### `{0}` ({1})\n", op.getOperationName(), className);
|
os << formatv("### `{0}` ({1})\n", op.getOperationName(), className);
|
||||||
|
|
||||||
// Emit the summary, syntax, and description if present.
|
// Emit the summary, syntax, and description if present.
|
||||||
if (op.hasSummary())
|
if (op.hasSummary())
|
||||||
@ -287,7 +287,7 @@ static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||||||
|
|
||||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||||
emitSourceLink(recordKeeper.getInputFilename(), os);
|
emitSourceLink(recordKeeper.getInputFilename(), os);
|
||||||
for (const llvm::Record *opDef : opDefs)
|
for (const Record *opDef : opDefs)
|
||||||
emitOpDoc(Operator(opDef), os);
|
emitOpDoc(Operator(opDef), os);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -339,7 +339,7 @@ static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def,
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
|
static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
|
||||||
os << llvm::formatv("### {0}\n", def.getCppClassName());
|
os << formatv("### {0}\n", def.getCppClassName());
|
||||||
|
|
||||||
// Emit the summary if present.
|
// Emit the summary if present.
|
||||||
if (def.hasSummary())
|
if (def.hasSummary())
|
||||||
@ -376,7 +376,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
|
|||||||
auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName);
|
auto defs = recordKeeper.getAllDerivedDefinitions(recordTypeName);
|
||||||
|
|
||||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||||
for (const llvm::Record *def : defs)
|
for (const Record *def : defs)
|
||||||
emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os);
|
emitAttrOrTypeDefDoc(AttrOrTypeDef(def), os);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,7 +385,7 @@ static void emitAttrOrTypeDefDoc(const RecordKeeper &recordKeeper,
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
|
static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
|
||||||
os << llvm::formatv("### {0}\n", def.getEnumClassName());
|
os << formatv("### {0}\n", def.getEnumClassName());
|
||||||
|
|
||||||
// Emit the summary if present.
|
// Emit the summary if present.
|
||||||
if (!def.getSummary().empty())
|
if (!def.getSummary().empty())
|
||||||
@ -406,8 +406,7 @@ static void emitEnumDoc(const EnumAttr &def, raw_ostream &os) {
|
|||||||
|
|
||||||
static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
static void emitEnumDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||||
for (const llvm::Record *def :
|
for (const Record *def : recordKeeper.getAllDerivedDefinitions("EnumAttr"))
|
||||||
recordKeeper.getAllDerivedDefinitions("EnumAttr"))
|
|
||||||
emitEnumDoc(EnumAttr(def), os);
|
emitEnumDoc(EnumAttr(def), os);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -431,7 +430,7 @@ struct OpDocGroup {
|
|||||||
static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
|
static void maybeNest(bool nest, llvm::function_ref<void(raw_ostream &os)> fn,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
std::string str;
|
std::string str;
|
||||||
llvm::raw_string_ostream ss(str);
|
raw_string_ostream ss(str);
|
||||||
fn(ss);
|
fn(ss);
|
||||||
for (StringRef x : llvm::split(str, "\n")) {
|
for (StringRef x : llvm::split(str, "\n")) {
|
||||||
if (nest && x.starts_with("#"))
|
if (nest && x.starts_with("#"))
|
||||||
@ -507,7 +506,7 @@ static void emitDialectDoc(const Dialect &dialect, StringRef inputFilename,
|
|||||||
emitIfNotEmpty(dialect.getDescription(), os);
|
emitIfNotEmpty(dialect.getDescription(), os);
|
||||||
|
|
||||||
// Generate a TOC marker except if description already contains one.
|
// Generate a TOC marker except if description already contains one.
|
||||||
llvm::Regex r("^[[:space:]]*\\[TOC\\]$", llvm::Regex::RegexFlags::Newline);
|
Regex r("^[[:space:]]*\\[TOC\\]$", Regex::RegexFlags::Newline);
|
||||||
if (!r.match(dialect.getDescription()))
|
if (!r.match(dialect.getDescription()))
|
||||||
os << "[TOC]\n\n";
|
os << "[TOC]\n\n";
|
||||||
|
|
||||||
@ -537,17 +536,15 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||||||
std::vector<TypeDef> dialectTypeDefs;
|
std::vector<TypeDef> dialectTypeDefs;
|
||||||
std::vector<EnumAttr> dialectEnums;
|
std::vector<EnumAttr> dialectEnums;
|
||||||
|
|
||||||
llvm::SmallDenseSet<const Record *> seen;
|
SmallDenseSet<const Record *> seen;
|
||||||
auto addIfNotSeen = [&](const llvm::Record *record, const auto &def,
|
auto addIfNotSeen = [&](const Record *record, const auto &def, auto &vec) {
|
||||||
auto &vec) {
|
|
||||||
if (seen.insert(record).second) {
|
if (seen.insert(record).second) {
|
||||||
vec.push_back(def);
|
vec.push_back(def);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
};
|
};
|
||||||
auto addIfInDialect = [&](const llvm::Record *record, const auto &def,
|
auto addIfInDialect = [&](const Record *record, const auto &def, auto &vec) {
|
||||||
auto &vec) {
|
|
||||||
return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
|
return def.getDialect() == *dialect && addIfNotSeen(record, def, vec);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -28,6 +28,9 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::formatv;
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::StringMap;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// VariableElement
|
// VariableElement
|
||||||
@ -404,7 +407,7 @@ struct OperationFormat {
|
|||||||
StringRef opCppClassName;
|
StringRef opCppClassName;
|
||||||
|
|
||||||
/// A map of buildable types to indices.
|
/// A map of buildable types to indices.
|
||||||
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
|
llvm::MapVector<StringRef, int, StringMap<int>> buildableTypes;
|
||||||
|
|
||||||
/// The index of the buildable type, if valid, for every operand and result.
|
/// The index of the buildable type, if valid, for every operand and result.
|
||||||
std::vector<TypeResolution> operandTypes, resultTypes;
|
std::vector<TypeResolution> operandTypes, resultTypes;
|
||||||
@ -891,8 +894,7 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
|||||||
|
|
||||||
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
|
} else if (auto *attr = dyn_cast<AttributeVariable>(element)) {
|
||||||
const NamedAttribute *var = attr->getVar();
|
const NamedAttribute *var = attr->getVar();
|
||||||
body << llvm::formatv(" {0} {1}Attr;\n", var->attr.getStorageType(),
|
body << formatv(" {0} {1}Attr;\n", var->attr.getStorageType(), var->name);
|
||||||
var->name);
|
|
||||||
|
|
||||||
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
|
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
|
||||||
StringRef name = operand->getVar()->name;
|
StringRef name = operand->getVar()->name;
|
||||||
@ -910,19 +912,19 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
|||||||
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
|
<< " ::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> "
|
||||||
<< name << "Operands(&" << name << "RawOperand, 1);";
|
<< name << "Operands(&" << name << "RawOperand, 1);";
|
||||||
}
|
}
|
||||||
body << llvm::formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
|
body << formatv(" ::llvm::SMLoc {0}OperandsLoc;\n"
|
||||||
" (void){0}OperandsLoc;\n",
|
" (void){0}OperandsLoc;\n",
|
||||||
name);
|
name);
|
||||||
|
|
||||||
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
||||||
StringRef name = region->getVar()->name;
|
StringRef name = region->getVar()->name;
|
||||||
if (region->getVar()->isVariadic()) {
|
if (region->getVar()->isVariadic()) {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
|
" ::llvm::SmallVector<std::unique_ptr<::mlir::Region>, 2> "
|
||||||
"{0}Regions;\n",
|
"{0}Regions;\n",
|
||||||
name);
|
name);
|
||||||
} else {
|
} else {
|
||||||
body << llvm::formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
|
body << formatv(" std::unique_ptr<::mlir::Region> {0}Region = "
|
||||||
"std::make_unique<::mlir::Region>();\n",
|
"std::make_unique<::mlir::Region>();\n",
|
||||||
name);
|
name);
|
||||||
}
|
}
|
||||||
@ -930,11 +932,11 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
|||||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||||
StringRef name = successor->getVar()->name;
|
StringRef name = successor->getVar()->name;
|
||||||
if (successor->getVar()->isVariadic()) {
|
if (successor->getVar()->isVariadic()) {
|
||||||
body << llvm::formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
|
body << formatv(" ::llvm::SmallVector<::mlir::Block *, 2> "
|
||||||
"{0}Successors;\n",
|
"{0}Successors;\n",
|
||||||
name);
|
name);
|
||||||
} else {
|
} else {
|
||||||
body << llvm::formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
|
body << formatv(" ::mlir::Block *{0}Successor = nullptr;\n", name);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||||
@ -944,8 +946,8 @@ static void genElementParserStorage(FormatElement *element, const Operator &op,
|
|||||||
body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
|
body << " ::llvm::SmallVector<::mlir::Type, 1> " << name << "Types;\n";
|
||||||
else
|
else
|
||||||
body
|
body
|
||||||
<< llvm::formatv(" ::mlir::Type {0}RawType{{};\n", name)
|
<< formatv(" ::mlir::Type {0}RawType{{};\n", name)
|
||||||
<< llvm::formatv(
|
<< formatv(
|
||||||
" ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
|
" ::llvm::ArrayRef<::mlir::Type> {0}Types(&{0}RawType, 1);\n",
|
||||||
name);
|
name);
|
||||||
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
|
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
|
||||||
@ -969,27 +971,27 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
|
|||||||
StringRef name = operand->getVar()->name;
|
StringRef name = operand->getVar()->name;
|
||||||
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
||||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
||||||
body << llvm::formatv("{0}OperandGroups", name);
|
body << formatv("{0}OperandGroups", name);
|
||||||
else if (lengthKind == ArgumentLengthKind::Variadic)
|
else if (lengthKind == ArgumentLengthKind::Variadic)
|
||||||
body << llvm::formatv("{0}Operands", name);
|
body << formatv("{0}Operands", name);
|
||||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||||
body << llvm::formatv("{0}Operand", name);
|
body << formatv("{0}Operand", name);
|
||||||
else
|
else
|
||||||
body << formatv("{0}RawOperand", name);
|
body << formatv("{0}RawOperand", name);
|
||||||
|
|
||||||
} else if (auto *region = dyn_cast<RegionVariable>(param)) {
|
} else if (auto *region = dyn_cast<RegionVariable>(param)) {
|
||||||
StringRef name = region->getVar()->name;
|
StringRef name = region->getVar()->name;
|
||||||
if (region->getVar()->isVariadic())
|
if (region->getVar()->isVariadic())
|
||||||
body << llvm::formatv("{0}Regions", name);
|
body << formatv("{0}Regions", name);
|
||||||
else
|
else
|
||||||
body << llvm::formatv("*{0}Region", name);
|
body << formatv("*{0}Region", name);
|
||||||
|
|
||||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
|
} else if (auto *successor = dyn_cast<SuccessorVariable>(param)) {
|
||||||
StringRef name = successor->getVar()->name;
|
StringRef name = successor->getVar()->name;
|
||||||
if (successor->getVar()->isVariadic())
|
if (successor->getVar()->isVariadic())
|
||||||
body << llvm::formatv("{0}Successors", name);
|
body << formatv("{0}Successors", name);
|
||||||
else
|
else
|
||||||
body << llvm::formatv("{0}Successor", name);
|
body << formatv("{0}Successor", name);
|
||||||
|
|
||||||
} else if (auto *dir = dyn_cast<RefDirective>(param)) {
|
} else if (auto *dir = dyn_cast<RefDirective>(param)) {
|
||||||
genCustomParameterParser(dir->getArg(), body);
|
genCustomParameterParser(dir->getArg(), body);
|
||||||
@ -998,11 +1000,11 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
|
|||||||
ArgumentLengthKind lengthKind;
|
ArgumentLengthKind lengthKind;
|
||||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
||||||
body << llvm::formatv("{0}TypeGroups", listName);
|
body << formatv("{0}TypeGroups", listName);
|
||||||
else if (lengthKind == ArgumentLengthKind::Variadic)
|
else if (lengthKind == ArgumentLengthKind::Variadic)
|
||||||
body << llvm::formatv("{0}Types", listName);
|
body << formatv("{0}Types", listName);
|
||||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||||
body << llvm::formatv("{0}Type", listName);
|
body << formatv("{0}Type", listName);
|
||||||
else
|
else
|
||||||
body << formatv("{0}RawType", listName);
|
body << formatv("{0}RawType", listName);
|
||||||
|
|
||||||
@ -1013,7 +1015,7 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
|
|||||||
body << tgfmt(string->getValue(), &ctx);
|
body << tgfmt(string->getValue(), &ctx);
|
||||||
|
|
||||||
} else if (auto *property = dyn_cast<PropertyVariable>(param)) {
|
} else if (auto *property = dyn_cast<PropertyVariable>(param)) {
|
||||||
body << llvm::formatv("result.getOrAddProperties<Properties>().{0}",
|
body << formatv("result.getOrAddProperties<Properties>().{0}",
|
||||||
property->getVar()->name);
|
property->getVar()->name);
|
||||||
} else {
|
} else {
|
||||||
llvm_unreachable("unknown custom directive parameter");
|
llvm_unreachable("unknown custom directive parameter");
|
||||||
@ -1037,12 +1039,12 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
|||||||
body << " " << var->name
|
body << " " << var->name
|
||||||
<< "OperandsLoc = parser.getCurrentLocation();\n";
|
<< "OperandsLoc = parser.getCurrentLocation();\n";
|
||||||
if (var->isOptional()) {
|
if (var->isOptional()) {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
|
" ::std::optional<::mlir::OpAsmParser::UnresolvedOperand> "
|
||||||
"{0}Operand;\n",
|
"{0}Operand;\n",
|
||||||
var->name);
|
var->name);
|
||||||
} else if (var->isVariadicOfVariadic()) {
|
} else if (var->isVariadicOfVariadic()) {
|
||||||
body << llvm::formatv(" "
|
body << formatv(" "
|
||||||
"::llvm::SmallVector<::llvm::SmallVector<::mlir::"
|
"::llvm::SmallVector<::llvm::SmallVector<::mlir::"
|
||||||
"OpAsmParser::UnresolvedOperand>> "
|
"OpAsmParser::UnresolvedOperand>> "
|
||||||
"{0}OperandGroups;\n",
|
"{0}OperandGroups;\n",
|
||||||
@ -1052,9 +1054,9 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
|||||||
ArgumentLengthKind lengthKind;
|
ArgumentLengthKind lengthKind;
|
||||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||||
if (lengthKind == ArgumentLengthKind::Optional) {
|
if (lengthKind == ArgumentLengthKind::Optional) {
|
||||||
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
|
body << formatv(" ::mlir::Type {0}Type;\n", listName);
|
||||||
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
|
" ::llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
|
||||||
"{0}TypeGroups;\n",
|
"{0}TypeGroups;\n",
|
||||||
listName);
|
listName);
|
||||||
@ -1064,7 +1066,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
|||||||
if (auto *operand = dyn_cast<OperandVariable>(input)) {
|
if (auto *operand = dyn_cast<OperandVariable>(input)) {
|
||||||
if (!operand->getVar()->isOptional())
|
if (!operand->getVar()->isOptional())
|
||||||
continue;
|
continue;
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" {0} {1}Operand = {1}Operands.empty() ? {0}() : "
|
" {0} {1}Operand = {1}Operands.empty() ? {0}() : "
|
||||||
"{1}Operands[0];\n",
|
"{1}Operands[0];\n",
|
||||||
"::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
|
"::std::optional<::mlir::OpAsmParser::UnresolvedOperand>",
|
||||||
@ -1074,7 +1076,7 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
|||||||
ArgumentLengthKind lengthKind;
|
ArgumentLengthKind lengthKind;
|
||||||
StringRef listName = getTypeListName(type->getArg(), lengthKind);
|
StringRef listName = getTypeListName(type->getArg(), lengthKind);
|
||||||
if (lengthKind == ArgumentLengthKind::Optional) {
|
if (lengthKind == ArgumentLengthKind::Optional) {
|
||||||
body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
|
body << formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? "
|
||||||
"::mlir::Type() : {0}Types[0];\n",
|
"::mlir::Type() : {0}Types[0];\n",
|
||||||
listName);
|
listName);
|
||||||
}
|
}
|
||||||
@ -1101,23 +1103,23 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
|||||||
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
|
if (auto *attr = dyn_cast<AttributeVariable>(param)) {
|
||||||
const NamedAttribute *var = attr->getVar();
|
const NamedAttribute *var = attr->getVar();
|
||||||
if (var->attr.isOptional() || var->attr.hasDefaultValue())
|
if (var->attr.isOptional() || var->attr.hasDefaultValue())
|
||||||
body << llvm::formatv(" if ({0}Attr)\n ", var->name);
|
body << formatv(" if ({0}Attr)\n ", var->name);
|
||||||
if (useProperties) {
|
if (useProperties) {
|
||||||
body << formatv(
|
body << formatv(
|
||||||
" result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
|
" result.getOrAddProperties<{1}::Properties>().{0} = {0}Attr;\n",
|
||||||
var->name, opCppClassName);
|
var->name, opCppClassName);
|
||||||
} else {
|
} else {
|
||||||
body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
|
body << formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",
|
||||||
var->name);
|
var->name);
|
||||||
}
|
}
|
||||||
} else if (auto *operand = dyn_cast<OperandVariable>(param)) {
|
} else if (auto *operand = dyn_cast<OperandVariable>(param)) {
|
||||||
const NamedTypeConstraint *var = operand->getVar();
|
const NamedTypeConstraint *var = operand->getVar();
|
||||||
if (var->isOptional()) {
|
if (var->isOptional()) {
|
||||||
body << llvm::formatv(" if ({0}Operand.has_value())\n"
|
body << formatv(" if ({0}Operand.has_value())\n"
|
||||||
" {0}Operands.push_back(*{0}Operand);\n",
|
" {0}Operands.push_back(*{0}Operand);\n",
|
||||||
var->name);
|
var->name);
|
||||||
} else if (var->isVariadicOfVariadic()) {
|
} else if (var->isVariadicOfVariadic()) {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" for (const auto &subRange : {0}OperandGroups) {{\n"
|
" for (const auto &subRange : {0}OperandGroups) {{\n"
|
||||||
" {0}Operands.append(subRange.begin(), subRange.end());\n"
|
" {0}Operands.append(subRange.begin(), subRange.end());\n"
|
||||||
" {0}OperandGroupSizes.push_back(subRange.size());\n"
|
" {0}OperandGroupSizes.push_back(subRange.size());\n"
|
||||||
@ -1128,11 +1130,11 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
|
|||||||
ArgumentLengthKind lengthKind;
|
ArgumentLengthKind lengthKind;
|
||||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||||
if (lengthKind == ArgumentLengthKind::Optional) {
|
if (lengthKind == ArgumentLengthKind::Optional) {
|
||||||
body << llvm::formatv(" if ({0}Type)\n"
|
body << formatv(" if ({0}Type)\n"
|
||||||
" {0}Types.push_back({0}Type);\n",
|
" {0}Types.push_back({0}Type);\n",
|
||||||
listName);
|
listName);
|
||||||
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" for (const auto &subRange : {0}TypeGroups)\n"
|
" for (const auto &subRange : {0}TypeGroups)\n"
|
||||||
" {0}Types.append(subRange.begin(), subRange.end());\n",
|
" {0}Types.append(subRange.begin(), subRange.end());\n",
|
||||||
listName);
|
listName);
|
||||||
@ -1460,7 +1462,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
|||||||
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
|
body << " if (" << attrVar->getVar()->name << "Attr) {\n";
|
||||||
} else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) {
|
} else if (auto *propVar = dyn_cast<PropertyVariable>(firstElement)) {
|
||||||
genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false);
|
genPropertyParser(propVar, body, opCppClassName, /*requireParse=*/false);
|
||||||
body << llvm::formatv("if ({0}PropParseResult.has_value() && "
|
body << formatv("if ({0}PropParseResult.has_value() && "
|
||||||
"succeeded(*{0}PropParseResult)) ",
|
"succeeded(*{0}PropParseResult)) ",
|
||||||
propVar->getVar()->name)
|
propVar->getVar()->name)
|
||||||
<< " {\n";
|
<< " {\n";
|
||||||
@ -1477,13 +1479,12 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
|||||||
genElementParser(regionVar, body, attrTypeCtx);
|
genElementParser(regionVar, body, attrTypeCtx);
|
||||||
body << " if (!" << region->name << "Regions.empty()) {\n";
|
body << " if (!" << region->name << "Regions.empty()) {\n";
|
||||||
} else {
|
} else {
|
||||||
body << llvm::formatv(optionalRegionParserCode, region->name);
|
body << formatv(optionalRegionParserCode, region->name);
|
||||||
body << " if (!" << region->name << "Region->empty()) {\n ";
|
body << " if (!" << region->name << "Region->empty()) {\n ";
|
||||||
if (hasImplicitTermTrait)
|
if (hasImplicitTermTrait)
|
||||||
body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
|
body << formatv(regionEnsureTerminatorParserCode, region->name);
|
||||||
else if (hasSingleBlockTrait)
|
else if (hasSingleBlockTrait)
|
||||||
body << llvm::formatv(regionEnsureSingleBlockParserCode,
|
body << formatv(regionEnsureSingleBlockParserCode, region->name);
|
||||||
region->name);
|
|
||||||
}
|
}
|
||||||
} else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
|
} else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
|
||||||
body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
|
body << " if (auto optResult = [&]() -> ::mlir::OptionalParseResult {\n";
|
||||||
@ -1575,24 +1576,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
|||||||
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
|
||||||
StringRef name = operand->getVar()->name;
|
StringRef name = operand->getVar()->name;
|
||||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
|
||||||
body << llvm::formatv(variadicOfVariadicOperandParserCode, name);
|
body << formatv(variadicOfVariadicOperandParserCode, name);
|
||||||
else if (lengthKind == ArgumentLengthKind::Variadic)
|
else if (lengthKind == ArgumentLengthKind::Variadic)
|
||||||
body << llvm::formatv(variadicOperandParserCode, name);
|
body << formatv(variadicOperandParserCode, name);
|
||||||
else if (lengthKind == ArgumentLengthKind::Optional)
|
else if (lengthKind == ArgumentLengthKind::Optional)
|
||||||
body << llvm::formatv(optionalOperandParserCode, name);
|
body << formatv(optionalOperandParserCode, name);
|
||||||
else
|
else
|
||||||
body << formatv(operandParserCode, name);
|
body << formatv(operandParserCode, name);
|
||||||
|
|
||||||
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
} else if (auto *region = dyn_cast<RegionVariable>(element)) {
|
||||||
bool isVariadic = region->getVar()->isVariadic();
|
bool isVariadic = region->getVar()->isVariadic();
|
||||||
body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
|
body << formatv(isVariadic ? regionListParserCode : regionParserCode,
|
||||||
region->getVar()->name);
|
region->getVar()->name);
|
||||||
if (hasImplicitTermTrait)
|
if (hasImplicitTermTrait)
|
||||||
body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
|
body << formatv(isVariadic ? regionListEnsureTerminatorParserCode
|
||||||
: regionEnsureTerminatorParserCode,
|
: regionEnsureTerminatorParserCode,
|
||||||
region->getVar()->name);
|
region->getVar()->name);
|
||||||
else if (hasSingleBlockTrait)
|
else if (hasSingleBlockTrait)
|
||||||
body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
|
body << formatv(isVariadic ? regionListEnsureSingleBlockParserCode
|
||||||
: regionEnsureSingleBlockParserCode,
|
: regionEnsureSingleBlockParserCode,
|
||||||
region->getVar()->name);
|
region->getVar()->name);
|
||||||
|
|
||||||
@ -1631,24 +1632,24 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
|
|||||||
<< " return ::mlir::failure();\n";
|
<< " return ::mlir::failure();\n";
|
||||||
|
|
||||||
} else if (isa<RegionsDirective>(element)) {
|
} else if (isa<RegionsDirective>(element)) {
|
||||||
body << llvm::formatv(regionListParserCode, "full");
|
body << formatv(regionListParserCode, "full");
|
||||||
if (hasImplicitTermTrait)
|
if (hasImplicitTermTrait)
|
||||||
body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
|
body << formatv(regionListEnsureTerminatorParserCode, "full");
|
||||||
else if (hasSingleBlockTrait)
|
else if (hasSingleBlockTrait)
|
||||||
body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
|
body << formatv(regionListEnsureSingleBlockParserCode, "full");
|
||||||
|
|
||||||
} else if (isa<SuccessorsDirective>(element)) {
|
} else if (isa<SuccessorsDirective>(element)) {
|
||||||
body << llvm::formatv(successorListParserCode, "full");
|
body << formatv(successorListParserCode, "full");
|
||||||
|
|
||||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||||
ArgumentLengthKind lengthKind;
|
ArgumentLengthKind lengthKind;
|
||||||
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
StringRef listName = getTypeListName(dir->getArg(), lengthKind);
|
||||||
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
|
||||||
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
|
body << formatv(variadicOfVariadicTypeParserCode, listName);
|
||||||
} else if (lengthKind == ArgumentLengthKind::Variadic) {
|
} else if (lengthKind == ArgumentLengthKind::Variadic) {
|
||||||
body << llvm::formatv(variadicTypeParserCode, listName);
|
body << formatv(variadicTypeParserCode, listName);
|
||||||
} else if (lengthKind == ArgumentLengthKind::Optional) {
|
} else if (lengthKind == ArgumentLengthKind::Optional) {
|
||||||
body << llvm::formatv(optionalTypeParserCode, listName);
|
body << formatv(optionalTypeParserCode, listName);
|
||||||
} else {
|
} else {
|
||||||
const char *parserCode =
|
const char *parserCode =
|
||||||
dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
|
dir->shouldBeQualified() ? qualifiedTypeParserCode : typeParserCode;
|
||||||
@ -1903,14 +1904,14 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
|
|||||||
if (!operand.isVariadicOfVariadic())
|
if (!operand.isVariadicOfVariadic())
|
||||||
continue;
|
continue;
|
||||||
if (op.getDialect().usePropertiesForAttributes()) {
|
if (op.getDialect().usePropertiesForAttributes()) {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" result.getOrAddProperties<{0}::Properties>().{1} = "
|
" result.getOrAddProperties<{0}::Properties>().{1} = "
|
||||||
"parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
|
"parser.getBuilder().getDenseI32ArrayAttr({2}OperandGroupSizes);\n",
|
||||||
op.getCppClassName(),
|
op.getCppClassName(),
|
||||||
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
|
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
|
||||||
operand.name);
|
operand.name);
|
||||||
} else {
|
} else {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" result.addAttribute(\"{0}\", "
|
" result.addAttribute(\"{0}\", "
|
||||||
"parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
|
"parser.getBuilder().getDenseI32ArrayAttr({1}OperandGroupSizes));"
|
||||||
"\n",
|
"\n",
|
||||||
@ -2160,7 +2161,7 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
|
|||||||
if (var->isVariadic())
|
if (var->isVariadic())
|
||||||
body << name << "().getTypes()";
|
body << name << "().getTypes()";
|
||||||
else if (var->isOptional())
|
else if (var->isOptional())
|
||||||
body << llvm::formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
|
body << formatv("({0}() ? {0}().getType() : ::mlir::Type())", name);
|
||||||
else
|
else
|
||||||
body << name << "().getType()";
|
body << name << "().getType()";
|
||||||
|
|
||||||
@ -2195,8 +2196,7 @@ static void genCustomDirectivePrinter(CustomDirective *customDir,
|
|||||||
static void genRegionPrinter(const Twine ®ionName, MethodBody &body,
|
static void genRegionPrinter(const Twine ®ionName, MethodBody &body,
|
||||||
bool hasImplicitTermTrait) {
|
bool hasImplicitTermTrait) {
|
||||||
if (hasImplicitTermTrait)
|
if (hasImplicitTermTrait)
|
||||||
body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode,
|
body << formatv(regionSingleBlockImplicitTerminatorPrinterCode, regionName);
|
||||||
regionName);
|
|
||||||
else
|
else
|
||||||
body << " _odsPrinter.printRegion(" << regionName << ");\n";
|
body << " _odsPrinter.printRegion(" << regionName << ");\n";
|
||||||
}
|
}
|
||||||
@ -2220,12 +2220,12 @@ static MethodBody &genTypeOperandPrinter(FormatElement *arg, const Operator &op,
|
|||||||
auto *operand = dyn_cast<OperandVariable>(arg);
|
auto *operand = dyn_cast<OperandVariable>(arg);
|
||||||
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
|
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
|
||||||
if (var->isVariadicOfVariadic())
|
if (var->isVariadicOfVariadic())
|
||||||
return body << llvm::formatv("{0}().join().getTypes()",
|
return body << formatv("{0}().join().getTypes()",
|
||||||
op.getGetterName(var->name));
|
op.getGetterName(var->name));
|
||||||
if (var->isVariadic())
|
if (var->isVariadic())
|
||||||
return body << op.getGetterName(var->name) << "().getTypes()";
|
return body << op.getGetterName(var->name) << "().getTypes()";
|
||||||
if (var->isOptional())
|
if (var->isOptional())
|
||||||
return body << llvm::formatv(
|
return body << formatv(
|
||||||
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
|
"({0}() ? ::llvm::ArrayRef<::mlir::Type>({0}().getType()) : "
|
||||||
"::llvm::ArrayRef<::mlir::Type>())",
|
"::llvm::ArrayRef<::mlir::Type>())",
|
||||||
op.getGetterName(var->name));
|
op.getGetterName(var->name));
|
||||||
@ -2242,7 +2242,7 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
|||||||
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
|
const EnumAttr &enumAttr = cast<EnumAttr>(baseAttr);
|
||||||
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
|
std::vector<EnumAttrCase> cases = enumAttr.getAllCases();
|
||||||
|
|
||||||
body << llvm::formatv(enumAttrBeginPrinterCode,
|
body << formatv(enumAttrBeginPrinterCode,
|
||||||
(var->attr.isOptional() ? "*" : "") +
|
(var->attr.isOptional() ? "*" : "") +
|
||||||
op.getGetterName(var->name),
|
op.getGetterName(var->name),
|
||||||
enumAttr.getSymbolToStringFnName());
|
enumAttr.getSymbolToStringFnName());
|
||||||
@ -2276,9 +2276,8 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
|
|||||||
if (nonKeywordCases.test(it.index()))
|
if (nonKeywordCases.test(it.index()))
|
||||||
continue;
|
continue;
|
||||||
StringRef symbol = it.value().getSymbol();
|
StringRef symbol = it.value().getSymbol();
|
||||||
body << llvm::formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
|
body << formatv(" case {0}::{1}::{2}:\n", cppNamespace, enumName,
|
||||||
llvm::isDigit(symbol.front()) ? ("_" + symbol)
|
llvm::isDigit(symbol.front()) ? ("_" + symbol) : symbol);
|
||||||
: symbol);
|
|
||||||
}
|
}
|
||||||
body << " _odsPrinter << caseValueStr;\n"
|
body << " _odsPrinter << caseValueStr;\n"
|
||||||
" break;\n"
|
" break;\n"
|
||||||
@ -2584,7 +2583,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
|
|||||||
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
|
||||||
if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
|
if (auto *operand = dyn_cast<OperandVariable>(dir->getArg())) {
|
||||||
if (operand->getVar()->isVariadicOfVariadic()) {
|
if (operand->getVar()->isVariadicOfVariadic()) {
|
||||||
body << llvm::formatv(
|
body << formatv(
|
||||||
" ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
|
" ::llvm::interleaveComma({0}().getTypes(), _odsPrinter, "
|
||||||
"[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
|
"[&](::mlir::TypeRange types) {{ _odsPrinter << \"(\" << "
|
||||||
"types << \")\"; });\n",
|
"types << \")\"; });\n",
|
||||||
@ -2710,7 +2709,7 @@ private:
|
|||||||
/// Verify the state of operation operands within the format.
|
/// Verify the state of operation operands within the format.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
verifyOperands(SMLoc loc,
|
verifyOperands(SMLoc loc,
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||||
|
|
||||||
/// Verify the state of operation regions within the format.
|
/// Verify the state of operation regions within the format.
|
||||||
LogicalResult verifyRegions(SMLoc loc);
|
LogicalResult verifyRegions(SMLoc loc);
|
||||||
@ -2718,7 +2717,7 @@ private:
|
|||||||
/// Verify the state of operation results within the format.
|
/// Verify the state of operation results within the format.
|
||||||
LogicalResult
|
LogicalResult
|
||||||
verifyResults(SMLoc loc,
|
verifyResults(SMLoc loc,
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||||
|
|
||||||
/// Verify the state of operation successors within the format.
|
/// Verify the state of operation successors within the format.
|
||||||
LogicalResult verifySuccessors(SMLoc loc);
|
LogicalResult verifySuccessors(SMLoc loc);
|
||||||
@ -2730,18 +2729,17 @@ private:
|
|||||||
/// resolution.
|
/// resolution.
|
||||||
void handleAllTypesMatchConstraint(
|
void handleAllTypesMatchConstraint(
|
||||||
ArrayRef<StringRef> values,
|
ArrayRef<StringRef> values,
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
|
StringMap<TypeResolutionInstance> &variableTyResolver);
|
||||||
/// Check for inferable type resolution given all operands, and or results,
|
/// Check for inferable type resolution given all operands, and or results,
|
||||||
/// have the same type. If 'includeResults' is true, the results also have the
|
/// have the same type. If 'includeResults' is true, the results also have the
|
||||||
/// same type as all of the operands.
|
/// same type as all of the operands.
|
||||||
void handleSameTypesConstraint(
|
void handleSameTypesConstraint(
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||||
bool includeResults);
|
bool includeResults);
|
||||||
/// Check for inferable type resolution based on another operand, result, or
|
/// Check for inferable type resolution based on another operand, result, or
|
||||||
/// attribute.
|
/// attribute.
|
||||||
void handleTypesMatchConstraint(
|
void handleTypesMatchConstraint(
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);
|
||||||
const llvm::Record &def);
|
|
||||||
|
|
||||||
/// Returns an argument or attribute with the given name that has been seen
|
/// Returns an argument or attribute with the given name that has been seen
|
||||||
/// within the format.
|
/// within the format.
|
||||||
@ -2794,9 +2792,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
|
|||||||
"custom assembly format");
|
"custom assembly format");
|
||||||
|
|
||||||
// Check for any type traits that we can use for inferring types.
|
// Check for any type traits that we can use for inferring types.
|
||||||
llvm::StringMap<TypeResolutionInstance> variableTyResolver;
|
StringMap<TypeResolutionInstance> variableTyResolver;
|
||||||
for (const Trait &trait : op.getTraits()) {
|
for (const Trait &trait : op.getTraits()) {
|
||||||
const llvm::Record &def = trait.getDef();
|
const Record &def = trait.getDef();
|
||||||
if (def.isSubClassOf("AllTypesMatch")) {
|
if (def.isSubClassOf("AllTypesMatch")) {
|
||||||
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
|
handleAllTypesMatchConstraint(def.getValueAsListOfStrings("values"),
|
||||||
variableTyResolver);
|
variableTyResolver);
|
||||||
@ -2995,8 +2993,7 @@ OpFormatParser::verifyAttributeColonType(SMLoc loc,
|
|||||||
return false;
|
return false;
|
||||||
// If we encounter `:`, the range is known to be invalid.
|
// If we encounter `:`, the range is known to be invalid.
|
||||||
(void)emitError(
|
(void)emitError(
|
||||||
loc,
|
loc, formatv("format ambiguity caused by `:` literal found after "
|
||||||
llvm::formatv("format ambiguity caused by `:` literal found after "
|
|
||||||
"attribute `{0}` which does not have a buildable type",
|
"attribute `{0}` which does not have a buildable type",
|
||||||
cast<AttributeVariable>(base)->getVar()->name));
|
cast<AttributeVariable>(base)->getVar()->name));
|
||||||
return true;
|
return true;
|
||||||
@ -3018,7 +3015,7 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
|
|||||||
return false;
|
return false;
|
||||||
(void)emitErrorAndNote(
|
(void)emitErrorAndNote(
|
||||||
loc,
|
loc,
|
||||||
llvm::formatv("format ambiguity caused by `attr-dict` directive "
|
formatv("format ambiguity caused by `attr-dict` directive "
|
||||||
"followed by region `{0}`",
|
"followed by region `{0}`",
|
||||||
region->getVar()->name),
|
region->getVar()->name),
|
||||||
"try using `attr-dict-with-keyword` instead");
|
"try using `attr-dict-with-keyword` instead");
|
||||||
@ -3028,7 +3025,7 @@ OpFormatParser::verifyAttrDictRegion(SMLoc loc,
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult OpFormatParser::verifyOperands(
|
LogicalResult OpFormatParser::verifyOperands(
|
||||||
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||||
// Check that all of the operands are within the format, and their types can
|
// Check that all of the operands are within the format, and their types can
|
||||||
// be inferred.
|
// be inferred.
|
||||||
auto &buildableTypes = fmt.buildableTypes;
|
auto &buildableTypes = fmt.buildableTypes;
|
||||||
@ -3093,7 +3090,7 @@ LogicalResult OpFormatParser::verifyRegions(SMLoc loc) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult OpFormatParser::verifyResults(
|
LogicalResult OpFormatParser::verifyResults(
|
||||||
SMLoc loc, llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
SMLoc loc, StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||||
// If we format all of the types together, there is nothing to check.
|
// If we format all of the types together, there is nothing to check.
|
||||||
if (fmt.allResultTypes)
|
if (fmt.allResultTypes)
|
||||||
return success();
|
return success();
|
||||||
@ -3197,7 +3194,7 @@ OpFormatParser::verifyOIListElements(SMLoc loc,
|
|||||||
|
|
||||||
void OpFormatParser::handleAllTypesMatchConstraint(
|
void OpFormatParser::handleAllTypesMatchConstraint(
|
||||||
ArrayRef<StringRef> values,
|
ArrayRef<StringRef> values,
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
|
StringMap<TypeResolutionInstance> &variableTyResolver) {
|
||||||
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
for (unsigned i = 0, e = values.size(); i != e; ++i) {
|
||||||
// Check to see if this value matches a resolved operand or result type.
|
// Check to see if this value matches a resolved operand or result type.
|
||||||
ConstArgument arg = findSeenArg(values[i]);
|
ConstArgument arg = findSeenArg(values[i]);
|
||||||
@ -3213,7 +3210,7 @@ void OpFormatParser::handleAllTypesMatchConstraint(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void OpFormatParser::handleSameTypesConstraint(
|
void OpFormatParser::handleSameTypesConstraint(
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
StringMap<TypeResolutionInstance> &variableTyResolver,
|
||||||
bool includeResults) {
|
bool includeResults) {
|
||||||
const NamedTypeConstraint *resolver = nullptr;
|
const NamedTypeConstraint *resolver = nullptr;
|
||||||
int resolvedIt = -1;
|
int resolvedIt = -1;
|
||||||
@ -3238,8 +3235,7 @@ void OpFormatParser::handleSameTypesConstraint(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void OpFormatParser::handleTypesMatchConstraint(
|
void OpFormatParser::handleTypesMatchConstraint(
|
||||||
llvm::StringMap<TypeResolutionInstance> &variableTyResolver,
|
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
|
||||||
const llvm::Record &def) {
|
|
||||||
StringRef lhsName = def.getValueAsString("lhs");
|
StringRef lhsName = def.getValueAsString("lhs");
|
||||||
StringRef rhsName = def.getValueAsString("rhs");
|
StringRef rhsName = def.getValueAsString("rhs");
|
||||||
StringRef transformer = def.getValueAsString("transformer");
|
StringRef transformer = def.getValueAsString("transformer");
|
||||||
|
@ -41,7 +41,7 @@ static std::string getOperationName(const Record &def) {
|
|||||||
auto opName = def.getValueAsString("opName");
|
auto opName = def.getValueAsString("opName");
|
||||||
if (prefix.empty())
|
if (prefix.empty())
|
||||||
return std::string(opName);
|
return std::string(opName);
|
||||||
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
|
return std::string(formatv("{0}.{1}", prefix, opName));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<const Record *>
|
std::vector<const Record *>
|
||||||
@ -50,7 +50,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
|
|||||||
if (!classDef)
|
if (!classDef)
|
||||||
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
|
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
|
||||||
|
|
||||||
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
|
Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
|
||||||
std::vector<const Record *> defs;
|
std::vector<const Record *> defs;
|
||||||
for (const auto &def : recordKeeper.getDefs()) {
|
for (const auto &def : recordKeeper.getDefs()) {
|
||||||
if (!def.second->isSubClassOf(classDef))
|
if (!def.second->isSubClassOf(classDef))
|
||||||
@ -70,7 +70,7 @@ mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool mlir::tblgen::isPythonReserved(StringRef str) {
|
bool mlir::tblgen::isPythonReserved(StringRef str) {
|
||||||
static llvm::StringSet<> reserved({
|
static StringSet<> reserved({
|
||||||
"False", "None", "True", "and", "as", "assert", "async",
|
"False", "None", "True", "and", "as", "assert", "async",
|
||||||
"await", "break", "class", "continue", "def", "del", "elif",
|
"await", "break", "class", "continue", "def", "del", "elif",
|
||||||
"else", "except", "finally", "for", "from", "global", "if",
|
"else", "except", "finally", "for", "from", "global", "if",
|
||||||
@ -86,8 +86,8 @@ bool mlir::tblgen::isPythonReserved(StringRef str) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void mlir::tblgen::shardOpDefinitions(
|
void mlir::tblgen::shardOpDefinitions(
|
||||||
ArrayRef<const llvm::Record *> defs,
|
ArrayRef<const Record *> defs,
|
||||||
SmallVectorImpl<ArrayRef<const llvm::Record *>> &shardedDefs) {
|
SmallVectorImpl<ArrayRef<const Record *>> &shardedDefs) {
|
||||||
assert(opShardCount > 0 && "expected a positive shard count");
|
assert(opShardCount > 0 && "expected a positive shard count");
|
||||||
if (opShardCount == 1) {
|
if (opShardCount == 1) {
|
||||||
shardedDefs.push_back(defs);
|
shardedDefs.push_back(defs);
|
||||||
|
@ -23,6 +23,8 @@
|
|||||||
#include "llvm/TableGen/TableGenBackend.h"
|
#include "llvm/TableGen/TableGenBackend.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
using mlir::tblgen::Interface;
|
using mlir::tblgen::Interface;
|
||||||
using mlir::tblgen::InterfaceMethod;
|
using mlir::tblgen::InterfaceMethod;
|
||||||
using mlir::tblgen::OpInterface;
|
using mlir::tblgen::OpInterface;
|
||||||
@ -61,14 +63,13 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
|
|||||||
|
|
||||||
/// Get an array of all OpInterface definitions but exclude those subclassing
|
/// Get an array of all OpInterface definitions but exclude those subclassing
|
||||||
/// "DeclareOpInterfaceMethods".
|
/// "DeclareOpInterfaceMethods".
|
||||||
static std::vector<const llvm::Record *>
|
static std::vector<const Record *>
|
||||||
getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
|
getAllInterfaceDefinitions(const RecordKeeper &recordKeeper, StringRef name) {
|
||||||
StringRef name) {
|
std::vector<const Record *> defs =
|
||||||
std::vector<const llvm::Record *> defs =
|
|
||||||
recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
|
recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
|
||||||
|
|
||||||
std::string declareName = ("Declare" + name + "InterfaceMethods").str();
|
std::string declareName = ("Declare" + name + "InterfaceMethods").str();
|
||||||
llvm::erase_if(defs, [&](const llvm::Record *def) {
|
llvm::erase_if(defs, [&](const Record *def) {
|
||||||
// Ignore any "declare methods" interfaces.
|
// Ignore any "declare methods" interfaces.
|
||||||
if (def->isSubClassOf(declareName))
|
if (def->isSubClassOf(declareName))
|
||||||
return true;
|
return true;
|
||||||
@ -88,7 +89,7 @@ public:
|
|||||||
bool emitInterfaceDocs();
|
bool emitInterfaceDocs();
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
InterfaceGenerator(std::vector<const llvm::Record *> &&defs, raw_ostream &os)
|
InterfaceGenerator(std::vector<const Record *> &&defs, raw_ostream &os)
|
||||||
: defs(std::move(defs)), os(os) {}
|
: defs(std::move(defs)), os(os) {}
|
||||||
|
|
||||||
void emitConceptDecl(const Interface &interface);
|
void emitConceptDecl(const Interface &interface);
|
||||||
@ -99,7 +100,7 @@ protected:
|
|||||||
void emitInterfaceDecl(const Interface &interface);
|
void emitInterfaceDecl(const Interface &interface);
|
||||||
|
|
||||||
/// The set of interface records to emit.
|
/// The set of interface records to emit.
|
||||||
std::vector<const llvm::Record *> defs;
|
std::vector<const Record *> defs;
|
||||||
// The stream to emit to.
|
// The stream to emit to.
|
||||||
raw_ostream &os;
|
raw_ostream &os;
|
||||||
/// The C++ value type of the interface, e.g. Operation*.
|
/// The C++ value type of the interface, e.g. Operation*.
|
||||||
@ -118,7 +119,7 @@ protected:
|
|||||||
|
|
||||||
/// A specialized generator for attribute interfaces.
|
/// A specialized generator for attribute interfaces.
|
||||||
struct AttrInterfaceGenerator : public InterfaceGenerator {
|
struct AttrInterfaceGenerator : public InterfaceGenerator {
|
||||||
AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
AttrInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||||
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
|
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
|
||||||
valueType = "::mlir::Attribute";
|
valueType = "::mlir::Attribute";
|
||||||
interfaceBaseType = "AttributeInterface";
|
interfaceBaseType = "AttributeInterface";
|
||||||
@ -133,7 +134,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
|
|||||||
};
|
};
|
||||||
/// A specialized generator for operation interfaces.
|
/// A specialized generator for operation interfaces.
|
||||||
struct OpInterfaceGenerator : public InterfaceGenerator {
|
struct OpInterfaceGenerator : public InterfaceGenerator {
|
||||||
OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
OpInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||||
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
|
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
|
||||||
valueType = "::mlir::Operation *";
|
valueType = "::mlir::Operation *";
|
||||||
interfaceBaseType = "OpInterface";
|
interfaceBaseType = "OpInterface";
|
||||||
@ -149,7 +150,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
|
|||||||
};
|
};
|
||||||
/// A specialized generator for type interfaces.
|
/// A specialized generator for type interfaces.
|
||||||
struct TypeInterfaceGenerator : public InterfaceGenerator {
|
struct TypeInterfaceGenerator : public InterfaceGenerator {
|
||||||
TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
|
TypeInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
|
||||||
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
|
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
|
||||||
valueType = "::mlir::Type";
|
valueType = "::mlir::Type";
|
||||||
interfaceBaseType = "TypeInterface";
|
interfaceBaseType = "TypeInterface";
|
||||||
@ -607,13 +608,13 @@ bool InterfaceGenerator::emitInterfaceDecls() {
|
|||||||
llvm::emitSourceFileHeader("Interface Declarations", os);
|
llvm::emitSourceFileHeader("Interface Declarations", os);
|
||||||
// Sort according to ID, so defs are emitted in the order in which they appear
|
// Sort according to ID, so defs are emitted in the order in which they appear
|
||||||
// in the Tablegen file.
|
// in the Tablegen file.
|
||||||
std::vector<const llvm::Record *> sortedDefs(defs);
|
std::vector<const Record *> sortedDefs(defs);
|
||||||
llvm::sort(sortedDefs, [](const llvm::Record *lhs, const llvm::Record *rhs) {
|
llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
|
||||||
return lhs->getID() < rhs->getID();
|
return lhs->getID() < rhs->getID();
|
||||||
});
|
});
|
||||||
for (const llvm::Record *def : sortedDefs)
|
for (const Record *def : sortedDefs)
|
||||||
emitInterfaceDecl(Interface(def));
|
emitInterfaceDecl(Interface(def));
|
||||||
for (const llvm::Record *def : sortedDefs)
|
for (const Record *def : sortedDefs)
|
||||||
emitModelMethodsDef(Interface(def));
|
emitModelMethodsDef(Interface(def));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -622,8 +623,7 @@ bool InterfaceGenerator::emitInterfaceDecls() {
|
|||||||
// GEN: Interface documentation
|
// GEN: Interface documentation
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void emitInterfaceDoc(const llvm::Record &interfaceDef,
|
static void emitInterfaceDoc(const Record &interfaceDef, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
Interface interface(&interfaceDef);
|
Interface interface(&interfaceDef);
|
||||||
|
|
||||||
// Emit the interface name followed by the description.
|
// Emit the interface name followed by the description.
|
||||||
@ -684,15 +684,15 @@ struct InterfaceGenRegistration {
|
|||||||
genDefDesc(("Generate " + genDesc + " interface definitions").str()),
|
genDefDesc(("Generate " + genDesc + " interface definitions").str()),
|
||||||
genDocDesc(("Generate " + genDesc + " interface documentation").str()),
|
genDocDesc(("Generate " + genDesc + " interface documentation").str()),
|
||||||
genDecls(genDeclArg, genDeclDesc,
|
genDecls(genDeclArg, genDeclDesc,
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
return GeneratorT(records, os).emitInterfaceDecls();
|
return GeneratorT(records, os).emitInterfaceDecls();
|
||||||
}),
|
}),
|
||||||
genDefs(genDefArg, genDefDesc,
|
genDefs(genDefArg, genDefDesc,
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
return GeneratorT(records, os).emitInterfaceDefs();
|
return GeneratorT(records, os).emitInterfaceDefs();
|
||||||
}),
|
}),
|
||||||
genDocs(genDocArg, genDocDesc,
|
genDocs(genDocArg, genDocDesc,
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
return GeneratorT(records, os).emitInterfaceDocs();
|
return GeneratorT(records, os).emitInterfaceDocs();
|
||||||
}) {}
|
}) {}
|
||||||
|
|
||||||
|
@ -23,6 +23,9 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::formatv;
|
||||||
|
using llvm::Record;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
|
||||||
/// File header and includes.
|
/// File header and includes.
|
||||||
/// {0} is the dialect namespace.
|
/// {0} is the dialect namespace.
|
||||||
@ -315,9 +318,9 @@ static std::string sanitizeName(StringRef name) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static std::string attrSizedTraitForKind(const char *kind) {
|
static std::string attrSizedTraitForKind(const char *kind) {
|
||||||
return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
|
return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
|
||||||
llvm::StringRef(kind).take_front().upper(),
|
StringRef(kind).take_front().upper(),
|
||||||
llvm::StringRef(kind).drop_front());
|
StringRef(kind).drop_front());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emits accessors to "elements" of an Op definition. Currently, the supported
|
/// Emits accessors to "elements" of an Op definition. Currently, the supported
|
||||||
@ -328,15 +331,14 @@ static void emitElementAccessors(
|
|||||||
unsigned numVariadicGroups, unsigned numElements,
|
unsigned numVariadicGroups, unsigned numElements,
|
||||||
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
||||||
getElement) {
|
getElement) {
|
||||||
assert(llvm::is_contained(
|
assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
|
||||||
llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
|
kind) &&
|
||||||
"unsupported kind");
|
"unsupported kind");
|
||||||
|
|
||||||
// Traits indicating how to process variadic elements.
|
// Traits indicating how to process variadic elements.
|
||||||
std::string sameSizeTrait =
|
std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
|
||||||
llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
|
StringRef(kind).take_front().upper(),
|
||||||
llvm::StringRef(kind).take_front().upper(),
|
StringRef(kind).drop_front());
|
||||||
llvm::StringRef(kind).drop_front());
|
|
||||||
std::string attrSizedTrait = attrSizedTraitForKind(kind);
|
std::string attrSizedTrait = attrSizedTraitForKind(kind);
|
||||||
|
|
||||||
// If there is only one variable-length element group, its size can be
|
// If there is only one variable-length element group, its size can be
|
||||||
@ -351,15 +353,14 @@ static void emitElementAccessors(
|
|||||||
if (element.name.empty())
|
if (element.name.empty())
|
||||||
continue;
|
continue;
|
||||||
if (element.isVariableLength()) {
|
if (element.isVariableLength()) {
|
||||||
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
|
os << formatv(element.isOptional() ? opOneOptionalTemplate
|
||||||
: opOneVariadicTemplate,
|
: opOneVariadicTemplate,
|
||||||
sanitizeName(element.name), kind, numElements, i);
|
sanitizeName(element.name), kind, numElements, i);
|
||||||
} else if (seenVariableLength) {
|
} else if (seenVariableLength) {
|
||||||
os << llvm::formatv(opSingleAfterVariableTemplate,
|
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
|
||||||
sanitizeName(element.name), kind, numElements, i);
|
kind, numElements, i);
|
||||||
} else {
|
} else {
|
||||||
os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
|
os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
|
||||||
i);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@ -382,11 +383,10 @@ static void emitElementAccessors(
|
|||||||
for (unsigned i = 0; i < numElements; ++i) {
|
for (unsigned i = 0; i < numElements; ++i) {
|
||||||
const NamedTypeConstraint &element = getElement(op, i);
|
const NamedTypeConstraint &element = getElement(op, i);
|
||||||
if (!element.name.empty()) {
|
if (!element.name.empty()) {
|
||||||
os << llvm::formatv(opVariadicEqualPrefixTemplate,
|
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
|
||||||
sanitizeName(element.name), kind, numSimpleLength,
|
kind, numSimpleLength, numVariadicGroups,
|
||||||
numVariadicGroups, numPrecedingSimple,
|
numPrecedingSimple, numPrecedingVariadic);
|
||||||
numPrecedingVariadic);
|
os << formatv(element.isVariableLength()
|
||||||
os << llvm::formatv(element.isVariableLength()
|
|
||||||
? opVariadicEqualVariadicTemplate
|
? opVariadicEqualVariadicTemplate
|
||||||
: opVariadicEqualSimpleTemplate,
|
: opVariadicEqualSimpleTemplate,
|
||||||
kind);
|
kind);
|
||||||
@ -412,9 +412,9 @@ static void emitElementAccessors(
|
|||||||
trailing = "[0]";
|
trailing = "[0]";
|
||||||
else if (element.isOptional())
|
else if (element.isOptional())
|
||||||
trailing = std::string(
|
trailing = std::string(
|
||||||
llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
|
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
|
||||||
os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
|
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
|
||||||
kind, i, trailing);
|
i, trailing);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -459,27 +459,21 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
|
|||||||
|
|
||||||
// Unit attributes are handled specially.
|
// Unit attributes are handled specially.
|
||||||
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
||||||
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,
|
os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
|
||||||
namedAttr.name);
|
os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
|
||||||
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,
|
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
|
||||||
namedAttr.name);
|
|
||||||
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
|
|
||||||
namedAttr.name);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (namedAttr.attr.isOptional()) {
|
if (namedAttr.attr.isOptional()) {
|
||||||
os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,
|
os << formatv(optionalAttributeGetterTemplate, sanitizedName,
|
||||||
namedAttr.name);
|
namedAttr.name);
|
||||||
os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,
|
os << formatv(optionalAttributeSetterTemplate, sanitizedName,
|
||||||
namedAttr.name);
|
|
||||||
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,
|
|
||||||
namedAttr.name);
|
namedAttr.name);
|
||||||
|
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
|
||||||
} else {
|
} else {
|
||||||
os << llvm::formatv(attributeGetterTemplate, sanitizedName,
|
os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
|
||||||
namedAttr.name);
|
os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
|
||||||
os << llvm::formatv(attributeSetterTemplate, sanitizedName,
|
|
||||||
namedAttr.name);
|
|
||||||
// Non-optional attributes cannot be deleted.
|
// Non-optional attributes cannot be deleted.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -595,7 +589,7 @@ static bool canInferType(const Operator &op) {
|
|||||||
/// accept them as arguments.
|
/// accept them as arguments.
|
||||||
static void
|
static void
|
||||||
populateBuilderArgsResults(const Operator &op,
|
populateBuilderArgsResults(const Operator &op,
|
||||||
llvm::SmallVectorImpl<std::string> &builderArgs) {
|
SmallVectorImpl<std::string> &builderArgs) {
|
||||||
if (canInferType(op))
|
if (canInferType(op))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
@ -607,7 +601,7 @@ populateBuilderArgsResults(const Operator &op,
|
|||||||
// to properly match the built-in result accessor.
|
// to properly match the built-in result accessor.
|
||||||
name = "result";
|
name = "result";
|
||||||
} else {
|
} else {
|
||||||
name = llvm::formatv("_gen_res_{0}", i);
|
name = formatv("_gen_res_{0}", i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
name = sanitizeName(name);
|
name = sanitizeName(name);
|
||||||
@ -620,14 +614,13 @@ populateBuilderArgsResults(const Operator &op,
|
|||||||
/// appear in the `arguments` field of the op definition. Additionally,
|
/// appear in the `arguments` field of the op definition. Additionally,
|
||||||
/// `operandNames` is populated with names of operands in their order of
|
/// `operandNames` is populated with names of operands in their order of
|
||||||
/// appearance.
|
/// appearance.
|
||||||
static void
|
static void populateBuilderArgs(const Operator &op,
|
||||||
populateBuilderArgs(const Operator &op,
|
SmallVectorImpl<std::string> &builderArgs,
|
||||||
llvm::SmallVectorImpl<std::string> &builderArgs,
|
SmallVectorImpl<std::string> &operandNames) {
|
||||||
llvm::SmallVectorImpl<std::string> &operandNames) {
|
|
||||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||||
std::string name = op.getArgName(i).str();
|
std::string name = op.getArgName(i).str();
|
||||||
if (name.empty())
|
if (name.empty())
|
||||||
name = llvm::formatv("_gen_arg_{0}", i);
|
name = formatv("_gen_arg_{0}", i);
|
||||||
name = sanitizeName(name);
|
name = sanitizeName(name);
|
||||||
builderArgs.push_back(name);
|
builderArgs.push_back(name);
|
||||||
if (!op.getArg(i).is<NamedAttribute *>())
|
if (!op.getArg(i).is<NamedAttribute *>())
|
||||||
@ -637,15 +630,16 @@ populateBuilderArgs(const Operator &op,
|
|||||||
|
|
||||||
/// Populates `builderArgs` with the Python-compatible names of builder function
|
/// Populates `builderArgs` with the Python-compatible names of builder function
|
||||||
/// successor arguments. Additionally, `successorArgNames` is also populated.
|
/// successor arguments. Additionally, `successorArgNames` is also populated.
|
||||||
static void populateBuilderArgsSuccessors(
|
static void
|
||||||
const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
|
populateBuilderArgsSuccessors(const Operator &op,
|
||||||
llvm::SmallVectorImpl<std::string> &successorArgNames) {
|
SmallVectorImpl<std::string> &builderArgs,
|
||||||
|
SmallVectorImpl<std::string> &successorArgNames) {
|
||||||
|
|
||||||
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
|
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
|
||||||
NamedSuccessor successor = op.getSuccessor(i);
|
NamedSuccessor successor = op.getSuccessor(i);
|
||||||
std::string name = std::string(successor.name);
|
std::string name = std::string(successor.name);
|
||||||
if (name.empty())
|
if (name.empty())
|
||||||
name = llvm::formatv("_gen_successor_{0}", i);
|
name = formatv("_gen_successor_{0}", i);
|
||||||
name = sanitizeName(name);
|
name = sanitizeName(name);
|
||||||
builderArgs.push_back(name);
|
builderArgs.push_back(name);
|
||||||
successorArgNames.push_back(name);
|
successorArgNames.push_back(name);
|
||||||
@ -658,9 +652,8 @@ static void populateBuilderArgsSuccessors(
|
|||||||
/// operands and attributes in the same order as they appear in the `arguments`
|
/// operands and attributes in the same order as they appear in the `arguments`
|
||||||
/// field.
|
/// field.
|
||||||
static void
|
static void
|
||||||
populateBuilderLinesAttr(const Operator &op,
|
populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
|
||||||
llvm::ArrayRef<std::string> argNames,
|
SmallVectorImpl<std::string> &builderLines) {
|
||||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
|
||||||
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
|
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
|
||||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||||
Argument arg = op.getArg(i);
|
Argument arg = op.getArg(i);
|
||||||
@ -670,12 +663,12 @@ populateBuilderLinesAttr(const Operator &op,
|
|||||||
|
|
||||||
// Unit attributes are handled specially.
|
// Unit attributes are handled specially.
|
||||||
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
||||||
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate,
|
builderLines.push_back(
|
||||||
attribute->name, argNames[i]));
|
formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
builderLines.push_back(llvm::formatv(
|
builderLines.push_back(formatv(
|
||||||
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
|
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
|
||||||
? initOptionalAttributeWithBuilderTemplate
|
? initOptionalAttributeWithBuilderTemplate
|
||||||
: initAttributeWithBuilderTemplate,
|
: initAttributeWithBuilderTemplate,
|
||||||
@ -686,30 +679,30 @@ populateBuilderLinesAttr(const Operator &op,
|
|||||||
/// Populates `builderLines` with additional lines that are required in the
|
/// Populates `builderLines` with additional lines that are required in the
|
||||||
/// builder to set up successors. successorArgNames is expected to correspond
|
/// builder to set up successors. successorArgNames is expected to correspond
|
||||||
/// to the Python argument name for each successor on the op.
|
/// to the Python argument name for each successor on the op.
|
||||||
static void populateBuilderLinesSuccessors(
|
static void
|
||||||
const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
|
populateBuilderLinesSuccessors(const Operator &op,
|
||||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
ArrayRef<std::string> successorArgNames,
|
||||||
|
SmallVectorImpl<std::string> &builderLines) {
|
||||||
if (successorArgNames.empty()) {
|
if (successorArgNames.empty()) {
|
||||||
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
|
builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
|
builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
|
||||||
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
|
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
|
||||||
auto &argName = successorArgNames[i];
|
auto &argName = successorArgNames[i];
|
||||||
const NamedSuccessor &successor = op.getSuccessor(i);
|
const NamedSuccessor &successor = op.getSuccessor(i);
|
||||||
builderLines.push_back(
|
builderLines.push_back(formatv(addSuccessorTemplate,
|
||||||
llvm::formatv(addSuccessorTemplate,
|
successor.isVariadic() ? "extend" : "append",
|
||||||
successor.isVariadic() ? "extend" : "append", argName));
|
argName));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Populates `builderLines` with additional lines that are required in the
|
/// Populates `builderLines` with additional lines that are required in the
|
||||||
/// builder to set up op operands.
|
/// builder to set up op operands.
|
||||||
static void
|
static void
|
||||||
populateBuilderLinesOperand(const Operator &op,
|
populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
|
||||||
llvm::ArrayRef<std::string> names,
|
SmallVectorImpl<std::string> &builderLines) {
|
||||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
|
||||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
|
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
|
||||||
|
|
||||||
// For each element, find or generate a name.
|
// For each element, find or generate a name.
|
||||||
@ -718,7 +711,7 @@ populateBuilderLinesOperand(const Operator &op,
|
|||||||
std::string name = names[i];
|
std::string name = names[i];
|
||||||
|
|
||||||
// Choose the formatting string based on the element kind.
|
// Choose the formatting string based on the element kind.
|
||||||
llvm::StringRef formatString;
|
StringRef formatString;
|
||||||
if (!element.isVariableLength()) {
|
if (!element.isVariableLength()) {
|
||||||
formatString = singleOperandAppendTemplate;
|
formatString = singleOperandAppendTemplate;
|
||||||
} else if (element.isOptional()) {
|
} else if (element.isOptional()) {
|
||||||
@ -738,7 +731,7 @@ populateBuilderLinesOperand(const Operator &op,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
builderLines.push_back(llvm::formatv(formatString.data(), name));
|
builderLines.push_back(formatv(formatString.data(), name));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -758,7 +751,7 @@ constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
|
|||||||
/// Appends the given multiline string as individual strings into
|
/// Appends the given multiline string as individual strings into
|
||||||
/// `builderLines`.
|
/// `builderLines`.
|
||||||
static void appendLineByLine(StringRef string,
|
static void appendLineByLine(StringRef string,
|
||||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
SmallVectorImpl<std::string> &builderLines) {
|
||||||
|
|
||||||
std::pair<StringRef, StringRef> split = std::make_pair(string, string);
|
std::pair<StringRef, StringRef> split = std::make_pair(string, string);
|
||||||
do {
|
do {
|
||||||
@ -770,14 +763,13 @@ static void appendLineByLine(StringRef string,
|
|||||||
/// Populates `builderLines` with additional lines that are required in the
|
/// Populates `builderLines` with additional lines that are required in the
|
||||||
/// builder to set up op results.
|
/// builder to set up op results.
|
||||||
static void
|
static void
|
||||||
populateBuilderLinesResult(const Operator &op,
|
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
|
||||||
llvm::ArrayRef<std::string> names,
|
SmallVectorImpl<std::string> &builderLines) {
|
||||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
|
||||||
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
|
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
|
||||||
|
|
||||||
if (hasSameArgumentAndResultTypes(op)) {
|
if (hasSameArgumentAndResultTypes(op)) {
|
||||||
builderLines.push_back(llvm::formatv(
|
builderLines.push_back(formatv(appendSameResultsTemplate,
|
||||||
appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
|
"operands[0].type", op.getNumResults()));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -785,10 +777,9 @@ populateBuilderLinesResult(const Operator &op,
|
|||||||
const NamedAttribute &firstAttr = op.getAttribute(0);
|
const NamedAttribute &firstAttr = op.getAttribute(0);
|
||||||
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
|
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
|
||||||
"from which the type is derived");
|
"from which the type is derived");
|
||||||
appendLineByLine(
|
appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
|
||||||
llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
|
|
||||||
builderLines);
|
builderLines);
|
||||||
builderLines.push_back(llvm::formatv(appendSameResultsTemplate,
|
builderLines.push_back(formatv(appendSameResultsTemplate,
|
||||||
"_ods_derived_result_type",
|
"_ods_derived_result_type",
|
||||||
op.getNumResults()));
|
op.getNumResults()));
|
||||||
return;
|
return;
|
||||||
@ -803,7 +794,7 @@ populateBuilderLinesResult(const Operator &op,
|
|||||||
std::string name = names[i];
|
std::string name = names[i];
|
||||||
|
|
||||||
// Choose the formatting string based on the element kind.
|
// Choose the formatting string based on the element kind.
|
||||||
llvm::StringRef formatString;
|
StringRef formatString;
|
||||||
if (!element.isVariableLength()) {
|
if (!element.isVariableLength()) {
|
||||||
formatString = singleResultAppendTemplate;
|
formatString = singleResultAppendTemplate;
|
||||||
} else if (element.isOptional()) {
|
} else if (element.isOptional()) {
|
||||||
@ -819,17 +810,16 @@ populateBuilderLinesResult(const Operator &op,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
builderLines.push_back(llvm::formatv(formatString.data(), name));
|
builderLines.push_back(formatv(formatString.data(), name));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If the operation has variadic regions, adds a builder argument to specify
|
/// If the operation has variadic regions, adds a builder argument to specify
|
||||||
/// the number of those regions and builder lines to forward it to the generic
|
/// the number of those regions and builder lines to forward it to the generic
|
||||||
/// constructor.
|
/// constructor.
|
||||||
static void
|
static void populateBuilderRegions(const Operator &op,
|
||||||
populateBuilderRegions(const Operator &op,
|
SmallVectorImpl<std::string> &builderArgs,
|
||||||
llvm::SmallVectorImpl<std::string> &builderArgs,
|
SmallVectorImpl<std::string> &builderLines) {
|
||||||
llvm::SmallVectorImpl<std::string> &builderLines) {
|
|
||||||
if (op.hasNoVariadicRegions())
|
if (op.hasNoVariadicRegions())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
@ -844,19 +834,19 @@ populateBuilderRegions(const Operator &op,
|
|||||||
.str();
|
.str();
|
||||||
builderArgs.push_back(name);
|
builderArgs.push_back(name);
|
||||||
builderLines.push_back(
|
builderLines.push_back(
|
||||||
llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
|
formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Emits a default builder constructing an operation from the list of its
|
/// Emits a default builder constructing an operation from the list of its
|
||||||
/// result types, followed by a list of its operands. Returns vector
|
/// result types, followed by a list of its operands. Returns vector
|
||||||
/// of fully built functionArgs for downstream users (to save having to
|
/// of fully built functionArgs for downstream users (to save having to
|
||||||
/// rebuild anew).
|
/// rebuild anew).
|
||||||
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
llvm::SmallVector<std::string> builderArgs;
|
SmallVector<std::string> builderArgs;
|
||||||
llvm::SmallVector<std::string> builderLines;
|
SmallVector<std::string> builderLines;
|
||||||
llvm::SmallVector<std::string> operandArgNames;
|
SmallVector<std::string> operandArgNames;
|
||||||
llvm::SmallVector<std::string> successorArgNames;
|
SmallVector<std::string> successorArgNames;
|
||||||
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
|
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
|
||||||
op.getNumNativeAttributes() + op.getNumSuccessors());
|
op.getNumNativeAttributes() + op.getNumSuccessors());
|
||||||
populateBuilderArgsResults(op, builderArgs);
|
populateBuilderArgsResults(op, builderArgs);
|
||||||
@ -866,10 +856,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
|||||||
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
|
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
|
||||||
|
|
||||||
populateBuilderLinesOperand(op, operandArgNames, builderLines);
|
populateBuilderLinesOperand(op, operandArgNames, builderLines);
|
||||||
populateBuilderLinesAttr(
|
populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
|
||||||
op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines);
|
builderLines);
|
||||||
populateBuilderLinesResult(
|
populateBuilderLinesResult(
|
||||||
op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
|
op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
|
||||||
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
|
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
|
||||||
populateBuilderRegions(op, builderArgs, builderLines);
|
populateBuilderRegions(op, builderArgs, builderLines);
|
||||||
|
|
||||||
@ -896,7 +886,7 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
|||||||
};
|
};
|
||||||
|
|
||||||
// StringRefs in functionArgs refer to strings allocated by builderArgs.
|
// StringRefs in functionArgs refer to strings allocated by builderArgs.
|
||||||
llvm::SmallVector<llvm::StringRef> functionArgs;
|
SmallVector<StringRef> functionArgs;
|
||||||
|
|
||||||
// Add positional arguments.
|
// Add positional arguments.
|
||||||
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
|
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
|
||||||
@ -929,11 +919,10 @@ static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
|||||||
initArgs.push_back("loc=loc");
|
initArgs.push_back("loc=loc");
|
||||||
initArgs.push_back("ip=ip");
|
initArgs.push_back("ip=ip");
|
||||||
|
|
||||||
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
|
os << formatv(initTemplate, llvm::join(functionArgs, ", "),
|
||||||
llvm::join(builderLines, "\n "),
|
llvm::join(builderLines, "\n "), llvm::join(initArgs, ", "));
|
||||||
llvm::join(initArgs, ", "));
|
|
||||||
return llvm::to_vector<8>(
|
return llvm::to_vector<8>(
|
||||||
llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
|
llvm::map_range(functionArgs, [](StringRef s) { return s.str(); }));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emitSegmentSpec(
|
static void emitSegmentSpec(
|
||||||
@ -955,14 +944,14 @@ static void emitSegmentSpec(
|
|||||||
}
|
}
|
||||||
segmentSpec.append("]");
|
segmentSpec.append("]");
|
||||||
|
|
||||||
os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
|
os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
|
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
|
||||||
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
|
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
|
||||||
// Note that the base OpView class defines this as (0, True).
|
// Note that the base OpView class defines this as (0, True).
|
||||||
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
|
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
|
||||||
os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
|
os << formatv(opClassRegionSpecTemplate, minRegionCount,
|
||||||
op.hasNoVariadicRegions() ? "True" : "False");
|
op.hasNoVariadicRegions() ? "True" : "False");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -975,7 +964,7 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
|
|||||||
|
|
||||||
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
|
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
|
||||||
"expected only the last region to be variadic");
|
"expected only the last region to be variadic");
|
||||||
os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name),
|
os << formatv(regionAccessorTemplate, sanitizeName(region.name),
|
||||||
std::to_string(en.index()) +
|
std::to_string(en.index()) +
|
||||||
(region.isVariadic() ? ":" : ""));
|
(region.isVariadic() ? ":" : ""));
|
||||||
}
|
}
|
||||||
@ -983,12 +972,12 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
|
|||||||
|
|
||||||
/// Emits builder that extracts results from op
|
/// Emits builder that extracts results from op
|
||||||
static void emitValueBuilder(const Operator &op,
|
static void emitValueBuilder(const Operator &op,
|
||||||
llvm::SmallVector<std::string> functionArgs,
|
SmallVector<std::string> functionArgs,
|
||||||
raw_ostream &os) {
|
raw_ostream &os) {
|
||||||
// Params with (possibly) default args.
|
// Params with (possibly) default args.
|
||||||
auto valueBuilderParams =
|
auto valueBuilderParams =
|
||||||
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
|
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
|
||||||
llvm::SmallVector<llvm::StringRef> argMaybeDefault =
|
SmallVector<StringRef> argMaybeDefault =
|
||||||
llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
|
llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
|
||||||
auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
|
auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
|
||||||
if (argMaybeDefault.size() == 2)
|
if (argMaybeDefault.size() == 2)
|
||||||
@ -1005,7 +994,7 @@ static void emitValueBuilder(const Operator &op,
|
|||||||
});
|
});
|
||||||
std::string nameWithoutDialect =
|
std::string nameWithoutDialect =
|
||||||
op.getOperationName().substr(op.getOperationName().find('.') + 1);
|
op.getOperationName().substr(op.getOperationName().find('.') + 1);
|
||||||
os << llvm::formatv(
|
os << formatv(
|
||||||
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
|
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
|
||||||
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
|
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
|
||||||
llvm::join(opBuilderArgs, ", "),
|
llvm::join(opBuilderArgs, ", "),
|
||||||
@ -1016,8 +1005,7 @@ static void emitValueBuilder(const Operator &op,
|
|||||||
|
|
||||||
/// Emits bindings for a specific Op to the given output stream.
|
/// Emits bindings for a specific Op to the given output stream.
|
||||||
static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
||||||
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
|
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());
|
||||||
op.getOperationName());
|
|
||||||
|
|
||||||
// Sized segments.
|
// Sized segments.
|
||||||
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
|
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
|
||||||
@ -1028,7 +1016,7 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
emitRegionAttributes(op, os);
|
emitRegionAttributes(op, os);
|
||||||
llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
|
SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
|
||||||
emitOperandAccessors(op, os);
|
emitOperandAccessors(op, os);
|
||||||
emitAttributeAccessors(op, os);
|
emitAttributeAccessors(op, os);
|
||||||
emitResultAccessors(op, os);
|
emitResultAccessors(op, os);
|
||||||
@ -1039,17 +1027,17 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
|||||||
/// Emits bindings for the dialect specified in the command line, including file
|
/// Emits bindings for the dialect specified in the command line, including file
|
||||||
/// headers and utilities. Returns `false` on success to comply with Tablegen
|
/// headers and utilities. Returns `false` on success to comply with Tablegen
|
||||||
/// registration requirements.
|
/// registration requirements.
|
||||||
static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
|
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
|
||||||
if (clDialectName.empty())
|
if (clDialectName.empty())
|
||||||
llvm::PrintFatalError("dialect name not provided");
|
llvm::PrintFatalError("dialect name not provided");
|
||||||
|
|
||||||
os << fileHeader;
|
os << fileHeader;
|
||||||
if (!clDialectExtensionName.empty())
|
if (!clDialectExtensionName.empty())
|
||||||
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
|
os << formatv(dialectExtensionTemplate, clDialectName.getValue());
|
||||||
else
|
else
|
||||||
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
|
os << formatv(dialectClassTemplate, clDialectName.getValue());
|
||||||
|
|
||||||
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
|
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||||
Operator op(rec);
|
Operator op(rec);
|
||||||
if (op.getDialectName() == clDialectName.getValue())
|
if (op.getDialectName() == clDialectName.getValue())
|
||||||
emitOpBindings(op, os);
|
emitOpBindings(op, os);
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::formatv;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
|
||||||
static llvm::cl::OptionCategory
|
static llvm::cl::OptionCategory
|
||||||
passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl");
|
passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl");
|
||||||
@ -56,7 +58,7 @@ const char *const fileFooter = R"(
|
|||||||
)";
|
)";
|
||||||
|
|
||||||
/// Emit TODO
|
/// Emit TODO
|
||||||
static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
|
static bool emitCAPIHeader(const RecordKeeper &records, raw_ostream &os) {
|
||||||
os << fileHeader;
|
os << fileHeader;
|
||||||
os << "// Registration for the entire group\n";
|
os << "// Registration for the entire group\n";
|
||||||
os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName
|
os << "MLIR_CAPI_EXPORTED void mlirRegister" << groupName
|
||||||
@ -64,7 +66,7 @@ static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
|
|||||||
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
|
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
|
||||||
Pass pass(def);
|
Pass pass(def);
|
||||||
StringRef defName = pass.getDef()->getName();
|
StringRef defName = pass.getDef()->getName();
|
||||||
os << llvm::formatv(passDecl, groupName, defName);
|
os << formatv(passDecl, groupName, defName);
|
||||||
}
|
}
|
||||||
os << fileFooter;
|
os << fileFooter;
|
||||||
return false;
|
return false;
|
||||||
@ -91,9 +93,9 @@ void mlirRegister{0}Passes(void) {{
|
|||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
|
static bool emitCAPIImpl(const RecordKeeper &records, raw_ostream &os) {
|
||||||
os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
|
os << "/* Autogenerated by mlir-tblgen; don't manually edit. */";
|
||||||
os << llvm::formatv(passGroupRegistrationCode, groupName);
|
os << formatv(passGroupRegistrationCode, groupName);
|
||||||
|
|
||||||
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
|
for (const auto *def : records.getAllDerivedDefinitions("PassBase")) {
|
||||||
Pass pass(def);
|
Pass pass(def);
|
||||||
@ -103,10 +105,9 @@ static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
|
|||||||
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
||||||
constructorCall = constructor.str();
|
constructorCall = constructor.str();
|
||||||
else
|
else
|
||||||
constructorCall =
|
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
|
||||||
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
|
|
||||||
|
|
||||||
os << llvm::formatv(passCreateDef, groupName, defName, constructorCall);
|
os << formatv(passCreateDef, groupName, defName, constructorCall);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
|
||||||
/// Emit the documentation for the given pass.
|
/// Emit the documentation for the given pass.
|
||||||
static void emitDoc(const Pass &pass, raw_ostream &os) {
|
static void emitDoc(const Pass &pass, raw_ostream &os) {
|
||||||
@ -56,7 +57,7 @@ static void emitDoc(const Pass &pass, raw_ostream &os) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
static void emitDocs(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||||
auto passDefs = recordKeeper.getAllDerivedDefinitions("PassBase");
|
auto passDefs = recordKeeper.getAllDerivedDefinitions("PassBase");
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ static void emitDocs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||||||
|
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genRegister("gen-pass-doc", "Generate pass documentation",
|
genRegister("gen-pass-doc", "Generate pass documentation",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
emitDocs(records, os);
|
emitDocs(records, os);
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
@ -21,6 +21,8 @@
|
|||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::tblgen;
|
using namespace mlir::tblgen;
|
||||||
|
using llvm::formatv;
|
||||||
|
using llvm::RecordKeeper;
|
||||||
|
|
||||||
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
|
static llvm::cl::OptionCategory passGenCat("Options for -gen-pass-decls");
|
||||||
static llvm::cl::opt<std::string>
|
static llvm::cl::opt<std::string>
|
||||||
@ -28,7 +30,7 @@ static llvm::cl::opt<std::string>
|
|||||||
llvm::cl::cat(passGenCat));
|
llvm::cl::cat(passGenCat));
|
||||||
|
|
||||||
/// Extract the list of passes from the TableGen records.
|
/// Extract the list of passes from the TableGen records.
|
||||||
static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) {
|
static std::vector<Pass> getPasses(const RecordKeeper &recordKeeper) {
|
||||||
std::vector<Pass> passes;
|
std::vector<Pass> passes;
|
||||||
|
|
||||||
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
|
for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase"))
|
||||||
@ -91,7 +93,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
|
|||||||
if (options.empty())
|
if (options.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
os << llvm::formatv("struct {0}Options {{\n", passName);
|
os << formatv("struct {0}Options {{\n", passName);
|
||||||
|
|
||||||
for (const PassOption &opt : options) {
|
for (const PassOption &opt : options) {
|
||||||
std::string type = opt.getType().str();
|
std::string type = opt.getType().str();
|
||||||
@ -99,7 +101,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) {
|
|||||||
if (opt.isListOption())
|
if (opt.isListOption())
|
||||||
type = "::llvm::SmallVector<" + type + ">";
|
type = "::llvm::SmallVector<" + type + ">";
|
||||||
|
|
||||||
os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName());
|
os.indent(2) << formatv("{0} {1}", type, opt.getCppVariableName());
|
||||||
|
|
||||||
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
|
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
|
||||||
os << " = " << defaultVal;
|
os << " = " << defaultVal;
|
||||||
@ -128,7 +130,7 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) {
|
|||||||
|
|
||||||
// Declaration of the constructor with options.
|
// Declaration of the constructor with options.
|
||||||
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
|
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty())
|
||||||
os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}("
|
os << formatv("std::unique_ptr<::mlir::Pass> create{0}("
|
||||||
"{0}Options options);\n",
|
"{0}Options options);\n",
|
||||||
passName);
|
passName);
|
||||||
}
|
}
|
||||||
@ -147,14 +149,13 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
|
|||||||
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
if (StringRef constructor = pass.getConstructor(); !constructor.empty())
|
||||||
constructorCall = constructor.str();
|
constructorCall = constructor.str();
|
||||||
else
|
else
|
||||||
constructorCall =
|
constructorCall = formatv("create{0}()", pass.getDef()->getName()).str();
|
||||||
llvm::formatv("create{0}()", pass.getDef()->getName()).str();
|
|
||||||
|
|
||||||
os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(),
|
os << formatv(passRegistrationCode, pass.getDef()->getName(),
|
||||||
constructorCall);
|
constructorCall);
|
||||||
}
|
}
|
||||||
|
|
||||||
os << llvm::formatv(passGroupRegistrationCode, groupName);
|
os << formatv(passGroupRegistrationCode, groupName);
|
||||||
|
|
||||||
for (const Pass &pass : passes)
|
for (const Pass &pass : passes)
|
||||||
os << " register" << pass.getDef()->getName() << "();\n";
|
os << " register" << pass.getDef()->getName() << "();\n";
|
||||||
@ -270,9 +271,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
|||||||
os.indent(2) << "::mlir::Pass::"
|
os.indent(2) << "::mlir::Pass::"
|
||||||
<< (opt.isListOption() ? "ListOption" : "Option");
|
<< (opt.isListOption() ? "ListOption" : "Option");
|
||||||
|
|
||||||
os << llvm::formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
|
os << formatv(R"(<{0}> {1}{{*this, "{2}", ::llvm::cl::desc("{3}"))",
|
||||||
opt.getType(), opt.getCppVariableName(),
|
opt.getType(), opt.getCppVariableName(), opt.getArgument(),
|
||||||
opt.getArgument(), opt.getDescription());
|
opt.getDescription());
|
||||||
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
|
if (std::optional<StringRef> defaultVal = opt.getDefaultValue())
|
||||||
os << ", ::llvm::cl::init(" << defaultVal << ")";
|
os << ", ::llvm::cl::init(" << defaultVal << ")";
|
||||||
if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
|
if (std::optional<StringRef> additionalFlags = opt.getAdditionalFlags())
|
||||||
@ -284,9 +285,9 @@ static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
|
|||||||
/// Emit the declarations for each of the pass statistics.
|
/// Emit the declarations for each of the pass statistics.
|
||||||
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
|
||||||
for (const PassStatistic &stat : pass.getStatistics()) {
|
for (const PassStatistic &stat : pass.getStatistics()) {
|
||||||
os << llvm::formatv(
|
os << formatv(" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
||||||
" ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n",
|
stat.getCppVariableName(), stat.getName(),
|
||||||
stat.getCppVariableName(), stat.getName(), stat.getDescription());
|
stat.getDescription());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -300,11 +301,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
|||||||
os << "#ifdef " << enableVarName << "\n";
|
os << "#ifdef " << enableVarName << "\n";
|
||||||
|
|
||||||
if (emitDefaultConstructors) {
|
if (emitDefaultConstructors) {
|
||||||
os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName);
|
os << formatv(friendDefaultConstructorDeclTemplate, passName);
|
||||||
|
|
||||||
if (emitDefaultConstructorWithOptions)
|
if (emitDefaultConstructorWithOptions)
|
||||||
os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate,
|
os << formatv(friendDefaultConstructorWithOptionsDeclTemplate, passName);
|
||||||
passName);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string dependentDialectRegistrations;
|
std::string dependentDialectRegistrations;
|
||||||
@ -313,23 +313,22 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
|||||||
llvm::interleave(
|
llvm::interleave(
|
||||||
pass.getDependentDialects(), dialectsOs,
|
pass.getDependentDialects(), dialectsOs,
|
||||||
[&](StringRef dependentDialect) {
|
[&](StringRef dependentDialect) {
|
||||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
|
||||||
dependentDialect);
|
|
||||||
},
|
},
|
||||||
"\n ");
|
"\n ");
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "namespace impl {\n";
|
os << "namespace impl {\n";
|
||||||
os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(),
|
os << formatv(baseClassBegin, passName, pass.getBaseClass(),
|
||||||
pass.getArgument(), pass.getSummary(),
|
pass.getArgument(), pass.getSummary(),
|
||||||
dependentDialectRegistrations);
|
dependentDialectRegistrations);
|
||||||
|
|
||||||
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
|
if (ArrayRef<PassOption> options = pass.getOptions(); !options.empty()) {
|
||||||
os.indent(2) << llvm::formatv(
|
os.indent(2) << formatv("{0}Base({0}Options options) : {0}Base() {{\n",
|
||||||
"{0}Base({0}Options options) : {0}Base() {{\n", passName);
|
passName);
|
||||||
|
|
||||||
for (const PassOption &opt : pass.getOptions())
|
for (const PassOption &opt : pass.getOptions())
|
||||||
os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n",
|
os.indent(4) << formatv("{0} = std::move(options.{0});\n",
|
||||||
opt.getCppVariableName());
|
opt.getCppVariableName());
|
||||||
|
|
||||||
os.indent(2) << "}\n";
|
os.indent(2) << "}\n";
|
||||||
@ -344,21 +343,20 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
|||||||
os << "private:\n";
|
os << "private:\n";
|
||||||
|
|
||||||
if (emitDefaultConstructors) {
|
if (emitDefaultConstructors) {
|
||||||
os << llvm::formatv(friendDefaultConstructorDefTemplate, passName);
|
os << formatv(friendDefaultConstructorDefTemplate, passName);
|
||||||
|
|
||||||
if (!pass.getOptions().empty())
|
if (!pass.getOptions().empty())
|
||||||
os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate,
|
os << formatv(friendDefaultConstructorWithOptionsDefTemplate, passName);
|
||||||
passName);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "};\n";
|
os << "};\n";
|
||||||
os << "} // namespace impl\n";
|
os << "} // namespace impl\n";
|
||||||
|
|
||||||
if (emitDefaultConstructors) {
|
if (emitDefaultConstructors) {
|
||||||
os << llvm::formatv(defaultConstructorDefTemplate, passName);
|
os << formatv(defaultConstructorDefTemplate, passName);
|
||||||
|
|
||||||
if (emitDefaultConstructorWithOptions)
|
if (emitDefaultConstructorWithOptions)
|
||||||
os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName);
|
os << formatv(defaultConstructorWithOptionsDefTemplate, passName);
|
||||||
}
|
}
|
||||||
|
|
||||||
os << "#undef " << enableVarName << "\n";
|
os << "#undef " << enableVarName << "\n";
|
||||||
@ -367,7 +365,7 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
|
|||||||
|
|
||||||
static void emitPass(const Pass &pass, raw_ostream &os) {
|
static void emitPass(const Pass &pass, raw_ostream &os) {
|
||||||
StringRef passName = pass.getDef()->getName();
|
StringRef passName = pass.getDef()->getName();
|
||||||
os << llvm::formatv(passHeader, passName);
|
os << formatv(passHeader, passName);
|
||||||
|
|
||||||
emitPassDecls(pass, os);
|
emitPassDecls(pass, os);
|
||||||
emitPassDefs(pass, os);
|
emitPassDefs(pass, os);
|
||||||
@ -436,12 +434,11 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
|
|||||||
llvm::interleave(
|
llvm::interleave(
|
||||||
pass.getDependentDialects(), dialectsOs,
|
pass.getDependentDialects(), dialectsOs,
|
||||||
[&](StringRef dependentDialect) {
|
[&](StringRef dependentDialect) {
|
||||||
dialectsOs << llvm::formatv(dialectRegistrationTemplate,
|
dialectsOs << formatv(dialectRegistrationTemplate, dependentDialect);
|
||||||
dependentDialect);
|
|
||||||
},
|
},
|
||||||
"\n ");
|
"\n ");
|
||||||
}
|
}
|
||||||
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
|
os << formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
|
||||||
pass.getArgument(), pass.getSummary(),
|
pass.getArgument(), pass.getSummary(),
|
||||||
dependentDialectRegistrations);
|
dependentDialectRegistrations);
|
||||||
emitPassOptionDecls(pass, os);
|
emitPassOptionDecls(pass, os);
|
||||||
@ -449,8 +446,7 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
|
|||||||
os << "};\n";
|
os << "};\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
static void emitPasses(const llvm::RecordKeeper &recordKeeper,
|
static void emitPasses(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||||
raw_ostream &os) {
|
|
||||||
std::vector<Pass> passes = getPasses(recordKeeper);
|
std::vector<Pass> passes = getPasses(recordKeeper);
|
||||||
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";
|
||||||
|
|
||||||
@ -479,7 +475,7 @@ static void emitPasses(const llvm::RecordKeeper &recordKeeper,
|
|||||||
|
|
||||||
static mlir::GenRegistration
|
static mlir::GenRegistration
|
||||||
genPassDecls("gen-pass-decls", "Generate pass declarations",
|
genPassDecls("gen-pass-decls", "Generate pass declarations",
|
||||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
[](const RecordKeeper &records, raw_ostream &os) {
|
||||||
emitPasses(records, os);
|
emitPasses(records, os);
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
@ -98,15 +98,15 @@ public:
|
|||||||
StringRef getMergeInstance() const;
|
StringRef getMergeInstance() const;
|
||||||
|
|
||||||
// Returns the underlying LLVM TableGen Record.
|
// Returns the underlying LLVM TableGen Record.
|
||||||
const llvm::Record *getDef() const { return def; }
|
const Record *getDef() const { return def; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// The TableGen definition of this availability.
|
// The TableGen definition of this availability.
|
||||||
const llvm::Record *def;
|
const Record *def;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Availability::Availability(const llvm::Record *def) : def(def) {
|
Availability::Availability(const Record *def) : def(def) {
|
||||||
assert(def->isSubClassOf("Availability") &&
|
assert(def->isSubClassOf("Availability") &&
|
||||||
"must be subclass of TableGen 'Availability' class");
|
"must be subclass of TableGen 'Availability' class");
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user