llvm-project/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp
Jacques Pienaar 0911558005 [mlir] Dialect type/attr bytecode read/write generator.
Tool to help generate dialect bytecode Attribute & Type readers/writers.
Show usage by flipping builtin dialect.

It helps reduce boilerplate when writing dialect bytecode attribute and
type readers/writers. It is not an attempt at a generic spec mechanism;
rather, it focuses practically on boilerplate reduction, while also
considering that it need not be the only in-memory format and aiming to
make it relatively easy to change.

There should be some cleanup in follow up as we expand to more dialects.

Differential Revision: https://reviews.llvm.org/D144820
2023-04-24 11:53:58 -07:00

468 lines
16 KiB
C++

//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <regex>
using namespace llvm;
// Command-line category that groups the options declared for this generator.
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode");
// Name of the dialect to generate for (`bytecode-dialect` option). When left
// empty, emitBCRW applies no dialect filter to the records it collects.
static llvm::cl::opt<std::string>
selectedBcDialect("bytecode-dialect",
llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated);
namespace {
/// Helper class to generate C++ bytecode reader/writer helpers. All emitted
/// text is written to the stream supplied at construction.
class Generator {
public:
Generator(raw_ostream &output) : output(output) {}
/// Emits the `read<RecordName>` function for the record `x`. `kind` is
/// forwarded to the parse helper.
void emitParse(StringRef kind, Record &x);
/// Emits the `write` function for the C++ type `type`, covering every
/// (code, record) pair in `vec`; the parameter of the emitted function is
/// named after `kind`.
void emitPrint(StringRef kind, StringRef type,
ArrayRef<std::pair<int64_t, Record *>> vec);
/// Emits parse dispatch table.
void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
/// Emits print dispatch table.
void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
private:
/// Emits parse calls to construct given kind.
void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
ArrayRef<Init *> args, ArrayRef<std::string> argNames,
StringRef failure, mlir::raw_indented_ostream &ios);
/// Emits print instructions.
void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
StringRef name, mlir::raw_indented_ostream &ios);
/// Stream that receives the generated C++.
raw_ostream &output;
};
} // namespace
/// Helper to replace set of from strings to target in `s`.
///
/// Every occurrence of each key of `map` in `templ` is substituted with the
/// key's mapped value and the resulting string is returned. Keys are matched
/// literally and values are inserted verbatim, so neither side is interpreted
/// as regular-expression or `$`-substitution syntax. Keys are applied one
/// after another in the map's sorted order, each over the whole string
/// produced so far; text inserted for a key is not rescanned for that key.
/// Empty keys are ignored.
/// Assumed: non-overlapping replacements.
static std::string format(StringRef templ,
                          std::map<std::string, std::string> &&map) {
  std::string s = templ.str();
  for (const auto &[from, to] : map) {
    // An empty key would match at every position; there is nothing to replace.
    if (from.empty())
      continue;
    // Resume searching after the inserted text so it is never re-matched.
    for (std::string::size_type pos = s.find(from); pos != std::string::npos;
         pos = s.find(from, pos + to.size()))
      s.replace(pos, from.size(), to);
  }
  return s;
}
/// Return string with first character capitalized.
/// An empty input yields an empty string.
static std::string capitalize(StringRef str) {
  // Indexing the first character of an empty string is out of bounds.
  if (str.empty())
    return std::string();
  return (Twine(toUpper(str[0])) + str.drop_front()).str();
}
/// Return the C++ type for the given record.
///
/// For an `Array` record the element record (`elemT`) supplies the type, which
/// is then wrapped as `SmallVector<...>`. The type is the record's `cType`
/// when that is non-empty and the record's own name otherwise; an anonymous
/// record without a `cType` is a fatal error.
static std::string getCType(Record *def) {
  const bool isArray = def->isSubClassOf("Array");
  if (isArray)
    def = def->getValueAsDef("elemT");
  const char *pattern = isArray ? "SmallVector<{0}>" : "{0}";

  StringRef cType = def->getValueAsString("cType");
  if (!cType.empty())
    return formatv(pattern, cType.str());

  // No explicit cType: fall back to the name of the def, if it has one.
  if (def->isAnonymous())
    PrintFatalError(def->getLoc(), "Unable to determine cType");
  return formatv(pattern, def->getName().str());
}
/// Emits `read<Kind>(context, reader)`, the entry point that reads a varint
/// code and dispatches to the `read<RecordName>` function of the record at
/// that position in `vec`.
///
/// The emitted function returns a default-constructed `<Kind>` when the code
/// cannot be read or is not a known position. The diagnostic for an unknown
/// code names `kind`, so the type reader reports "unknown type code" rather
/// than mentioning attributes.
void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
  mlir::raw_indented_ostream os(output);
  const std::string kindClass = capitalize(kind);
  char const *head =
      R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
  os << formatv(head, kindClass);
  auto funScope = os.scope(" {\n", "}\n\n");

  os << "uint64_t kind;\n";
  os << "if (failed(reader.readVarInt(kind)))\n"
     << " return " << kindClass << "();\n";

  os << "switch (kind) ";
  {
    auto switchScope = os.scope("{\n", "}\n");
    // The case label is the record's position in `vec`.
    for (const auto &it : llvm::enumerate(vec)) {
      os << formatv("case {1}:\n return read{0}(context, reader);\n",
                    it.value()->getName(), it.index());
    }
    os << "default:\n"
       << " reader.emitError() << \"unknown " << kind << " code: \" "
       << "<< kind;\n"
       << " return " << kindClass << "();\n";
  }

  // Every switch branch returns; this trailing return follows the switch in
  // the emitted function.
  os << "return " << kindClass << "();\n";
}
/// Emits the `read<RecordName>` function for record `x`.
///
/// The signature is written here; the body is produced by emitParseHelper from
/// the record's `members` dag and `cBuilder` string, with a default-constructed
/// return type as the failure value.
void Generator::emitParse(StringRef kind, Record &x) {
  mlir::raw_indented_ostream os(output);

  const std::string resultType = getCType(&x);
  os << formatv(
      R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )",
      resultType, x.getName());

  // Member names double as the local variable names in the emitted body.
  DagInit *members = x.getValueAsDag("members");
  SmallVector<std::string> memberNames;
  for (StringInit *memberName : members->getArgNames())
    memberNames.push_back(memberName->getAsUnquotedString());

  StringRef cBuilder = x.getValueAsString("cBuilder");
  const std::string onFailure = resultType + "()";
  emitParseHelper(kind, resultType, cBuilder, members->getArgs(), memberNames,
                  onFailure, os);
  os << "\n\n";
}
/// Emits the head of the `if (<parse calls>) {` statement that reads the
/// members into their like-named local variables.
///
/// `args` and `argNames` are parallel: the i-th name belongs to the i-th
/// member. Only members that are `Array`s or carry a non-empty `cParser`
/// produce a parse call; the calls are joined with `&&`. The scope object
/// opened here is given `(` and `) {` as its delimiters, and the extra
/// `indent()` is intentionally left in effect for the caller to undo after it
/// has emitted the body of the `if`.
void printParseConditional(mlir::raw_indented_ostream &ios,
                           ArrayRef<Init *> args,
                           ArrayRef<std::string> argNames) {
  ios << "if ";
  auto parenScope = ios.scope("(", ") {");
  ios.indent();

  // Select the members to parse together with their own names. Filtering the
  // members alone and then pairing the survivors with the unfiltered name list
  // would attach the wrong variable name to every member after a skipped one.
  SmallVector<std::pair<Record *, StringRef>> parsedMembers;
  for (auto [arg, name] : zip(args, argNames)) {
    Record *def = cast<DefInit>(arg)->getDef();
    if (def->isSubClassOf("Array") ||
        !def->getValueAsString("cParser").empty())
      parsedMembers.emplace_back(def, name);
  }

  interleave(
      parsedMembers,
      [&](const std::pair<Record *, StringRef> &member) {
        Record *attr = member.first;
        StringRef name = member.second;

        std::string parser;
        if (auto optParser = attr->getValueAsOptionalString("cParser")) {
          parser = optParser->str();
        } else if (attr->isSubClassOf("Array")) {
          Record *def = attr->getValueAsDef("elemT");
          bool composite = def->isSubClassOf("CompositeBytecode");
          if (!composite && def->isSubClassOf("AttributeKind"))
            parser = "succeeded($_reader.readAttributes($_var))";
          else if (!composite && def->isSubClassOf("TypeKind"))
            parser = "succeeded($_reader.readTypes($_var))";
          else
            // Other element kinds are read through the `read<Name>` helper
            // named after the member.
            parser = "succeeded($_reader.readList($_var, read" +
                     capitalize(name) + "))";
        } else {
          PrintFatalError(attr->getLoc(), "No parser specified");
        }

        std::string type = getCType(attr);
        ios << format(parser, {{"$_reader", "reader"},
                               {"$_resultType", type},
                               {"$_var", name.str()}});
      },
      [&]() { ios << " &&\n"; });
}
/// Emits a brace-delimited body that declares one local per member, reads the
/// members, and returns the value built from them.
///
/// \param kind       Only forwarded to the recursive call below.
/// \param returnType C++ type substituted for `$_resultType` in `builder`.
/// \param builder    Template for the constructing expression; `$_args` is
///                   replaced by the comma-separated member names that do not
///                   start with `_`.
/// \param args       Member records (each must be a DefInit).
/// \param argNames   Member names, parallel to `args`; used as variable names.
/// \param failure    Expression emitted in the final `return` that follows
///                   the parse conditional.
/// \param ios        Indented stream receiving the output.
void Generator::emitParseHelper(StringRef kind, StringRef returnType,
StringRef builder, ArrayRef<Init *> args,
ArrayRef<std::string> argNames,
StringRef failure,
mlir::raw_indented_ostream &ios) {
auto funScope = ios.scope("{\n", "}");
// Nothing to read: the emitted body is a single `get<returnType>(context)`.
if (args.empty()) {
ios << formatv("return get<{0}>(context);\n", returnType);
return;
}
// Print decls.
std::string lastCType = "";
for (auto [arg, name] : zip(args, argNames)) {
DefInit *first = dyn_cast<DefInit>(arg);
if (!first)
PrintFatalError("Unexpected type for " + name);
Record *def = first->getDef();
// Create variable decls, if there are a block of same type then create
// comma separated list of them.
std::string cType = getCType(def);
if (lastCType == cType) {
ios << ", ";
} else {
if (!lastCType.empty())
ios << ";\n";
ios << cType << " ";
}
ios << name;
lastCType = cType;
}
ios << ";\n";
// Returns the name of the helper used in list parsing. E.g., the name of the
// lambda passed to array parsing.
auto listHelperName = [](StringRef name) {
return formatv("read{0}", capitalize(name));
};
// Emit list helper functions.
for (auto [arg, name] : zip(args, argNames)) {
Record *attr = cast<DefInit>(arg)->getDef();
if (!attr->isSubClassOf("Array"))
continue;
// TODO: Dedupe readers.
Record *def = attr->getValueAsDef("elemT");
// Arrays of plain attributes/types need no helper lambda.
if (!def->isSubClassOf("CompositeBytecode") &&
(def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
continue;
// NOTE: `returnType`, `args`, `argNames` and `builder` below shadow the
// function parameters for the duration of this loop body.
std::string returnType = getCType(def);
ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
<< returnType << "> ";
SmallVector<Init *> args;
SmallVector<std::string> argNames;
if (def->isSubClassOf("CompositeBytecode")) {
// Composite element: the lambda reads each of the element's members.
DagInit *members = def->getValueAsDag("members");
args = llvm::to_vector(members->getArgs());
argNames = llvm::to_vector(
map_range(members->getArgNames(), [](StringInit *init) {
return init->getAsUnquotedString();
}));
} else {
// Non-composite element: read it into a single variable named `temp`.
args = {def->getDefInit()};
argNames = {"temp"};
}
StringRef builder = def->getValueAsString("cBuilder");
// The lambda body is generated recursively, with `failure()` as its
// failure value.
emitParseHelper(kind, returnType, builder, args, argNames, "failure()",
ios);
ios << ";\n";
}
// Print parse conditional.
printParseConditional(ios, args, argNames);
// Compute args to pass to create method.
auto passedArgs = llvm::to_vector(make_filter_range(
argNames, [](StringRef str) { return !str.starts_with("_"); }));
std::string argStr;
raw_string_ostream argStream(argStr);
interleaveComma(passedArgs, argStream,
[&](const std::string &str) { argStream << str; });
// Return the invoked constructor.
ios << "\nreturn "
<< format(builder, {{"$_resultType", returnType.str()},
{"$_args", argStream.str()}})
<< ";\n";
// Pairs with the indent() that printParseConditional leaves in effect.
ios.unindent();
// TODO: Emit error in debug.
// This assumes the result types in error case can always be empty
// constructed.
ios << "}\nreturn " << failure << ";\n";
}
/// Emits the `write(<type> <kind>, writer)` function covering every
/// (code, record) pair in `vec`.
///
/// When more than one record shares the C++ type, each record must provide a
/// `printerPredicate`; otherwise every record lacking one is reported and
/// generation stops with a fatal error. For each record the emitted code
/// writes the record's code as a varint followed by its members, wrapped in
/// `if (<predicate>)` when a predicate is present.
void Generator::emitPrint(StringRef kind, StringRef type,
                          ArrayRef<std::pair<int64_t, Record *>> vec) {
  mlir::raw_indented_ostream os(output);
  os << formatv(
      R"(static void write({0} {1}, DialectBytecodeWriter &writer) )", type,
      kind);
  auto funScope = os.scope("{\n", "}\n\n");

  // Check that predicates specified if multiple bytecode instances.
  const bool sharesCType = vec.size() > 1;
  auto lacksPredicate = [](Record *rec) {
    return rec->getValueAsString("printerPredicate").empty();
  };
  for (Record *candidate : make_second_range(vec)) {
    const bool missing = lacksPredicate(candidate);
    if (!sharesCType || !missing)
      continue;
    // Report every offending record before giving up.
    for (const auto &entry : vec) {
      Record *offender = entry.second;
      if (lacksPredicate(offender))
        PrintError(offender->getLoc(),
                   "Requires parsing predicate given common cType");
    }
    PrintFatalError("Unspecified for shared cType " + type);
  }

  for (const auto &entry : vec) {
    const int64_t code = entry.first;
    Record *rec = entry.second;

    StringRef pred = rec->getValueAsString("printerPredicate");
    const bool guarded = !pred.empty();
    if (guarded) {
      os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
      os.indent();
    }

    os << "writer.writeVarInt(/* " << rec->getName() << " */ " << code
       << ");\n";

    auto *members = rec->getValueAsDag("members");
    for (auto [arg, name] :
         llvm::zip(members->getArgs(), members->getArgNames())) {
      DefInit *def = dyn_cast<DefInit>(arg);
      assert(def);
      emitPrintHelper(def->getDef(), kind, kind, name->getAsUnquotedString(),
                      os);
    }

    if (guarded) {
      os.unindent();
      os << "}\n";
    }
  }
}
/// Emits the statements that write one member.
///
/// The getter expression is the record's non-empty `cGetter` template (with
/// `$_attrType`, `$_member` and `$_getMember` substituted) or, failing that,
/// `<parent>.get<CamelName>()`. Arrays of plain attributes/types are written
/// with `writeAttributes`/`writeTypes`; other arrays go through `writeList`
/// with a lambda whose body is emitted recursively. For a `CompositeBytecode`
/// record each of its members is emitted in turn. Finally the record's
/// `cPrinter` template, when non-empty, is emitted with `$_writer`, `$_name`
/// and `$_getter` substituted.
void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
                                StringRef parent, StringRef name,
                                mlir::raw_indented_ostream &ios) {
  std::string getter;
  auto customGetter = memberRec->getValueAsOptionalString("cGetter");
  const std::string accessor = "get" + convertToCamelFromSnakeCase(name, true);
  if (customGetter && !customGetter->empty())
    getter = format(*customGetter, {{"$_attrType", parent.str()},
                                    {"$_member", name.str()},
                                    {"$_getMember", accessor}});
  else
    getter = parent.str() + "." + accessor + "()";

  if (memberRec->isSubClassOf("Array")) {
    Record *elem = memberRec->getValueAsDef("elemT");
    const bool composite = elem->isSubClassOf("CompositeBytecode");
    if (!composite && elem->isSubClassOf("AttributeKind")) {
      ios << "writer.writeAttributes(" << getter << ");\n";
      return;
    }
    if (!composite && elem->isSubClassOf("TypeKind")) {
      ios << "writer.writeTypes(" << getter << ");\n";
      return;
    }

    // Per-element lambda; its parameter reuses `kind` as the variable name.
    std::string elemCType = getCType(elem);
    ios << "writer.writeList(" << getter << ", [&](" << elemCType << " "
        << kind << ") ";
    auto lambdaScope = ios.scope("{\n", "});\n");
    emitPrintHelper(elem, kind, kind, kind, ios);
    return;
  }

  if (memberRec->isSubClassOf("CompositeBytecode")) {
    auto *members = memberRec->getValueAsDag("members");
    for (auto [arg, argName] :
         zip(members->getArgs(), members->getArgNames())) {
      DefInit *def = dyn_cast<DefInit>(arg);
      assert(def);
      emitPrintHelper(def->getDef(), kind, parent,
                      argName->getAsUnquotedString(), ios);
    }
  }

  std::string printer = memberRec->getValueAsString("cPrinter").str();
  if (printer.empty())
    return;
  ios << format(printer, {{"$_writer", "writer"},
                          {"$_name", kind.str()},
                          {"$_getter", getter}})
      << ";\n";
}
/// Emits `write<Kind>(<Kind> <kind>, writer)`, a `TypeSwitch` over the C++
/// types in `vec`: each case forwards to the matching `write` overload and
/// yields `success()`, and the default case yields `failure()`.
void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
  mlir::raw_indented_ostream os(output);
  const std::string kindClass = capitalize(kind);
  // The line break inside the raw literal is part of the emitted text.
  os << formatv(R"(static LogicalResult write{0}({0} {1},
DialectBytecodeWriter &writer))",
                kindClass, kind);
  auto funScope = os.scope(" {\n", "}\n\n");

  os << "return TypeSwitch<" << kindClass << ", LogicalResult>(" << kind
     << ")";
  auto switchScope = os.scope("", "");
  for (const std::string &cType : vec) {
    os << "\n.Case([&](" << cType << " t)";
    auto caseScope = os.scope(" {\n", "})");
    os << "return write(t, writer), success();\n";
  }
  os << "\n.Default([&](" << kindClass << ") { return failure(); });\n";
}
namespace {
/// Container of Attribute or Type for Dialect.
struct AttrOrType {
// `attr` receives the `elems` of a `DialectAttributes` def and `type` the
// `elems` of a `DialectTypes` def (filled in by emitBCRW).
std::vector<Record *> attr, type;
};
} // namespace
static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
MapVector<StringRef, AttrOrType> dialectAttrOrType;
for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
if (!selectedBcDialect.empty() &&
it->getValueAsString("dialect") != selectedBcDialect)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].attr =
it->getValueAsListOfDefs("elems");
}
for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
if (!selectedBcDialect.empty() &&
it->getValueAsString("dialect") != selectedBcDialect)
continue;
dialectAttrOrType[it->getValueAsString("dialect")].type =
it->getValueAsListOfDefs("elems");
}
if (dialectAttrOrType.size() != 1)
PrintFatalError("Single dialect per invocation required (either only "
"one in input file or specified via dialect option)");
auto it = dialectAttrOrType.front();
Generator gen(os);
SmallVector<std::vector<Record *> *, 2> vecs;
SmallVector<std::string, 2> kinds;
vecs.push_back(&it.second.attr);
kinds.push_back("attribute");
vecs.push_back(&it.second.type);
kinds.push_back("type");
for (auto [vec, kind] : zip(vecs, kinds)) {
// Handle Attribute/Type emission.
std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
for (auto kt : llvm::enumerate(*vec))
perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
for (const auto &jt : perType) {
for (auto kt : jt.second)
gen.emitParse(kind, *std::get<1>(kt));
gen.emitPrint(kind, jt.first, jt.second);
}
gen.emitParseDispatch(kind, *vec);
SmallVector<std::string> types;
for (const auto &it : perType) {
types.push_back(it.first);
}
gen.emitPrintDispatch(kind, types);
}
return false;
}
// Registers the generator under the name "gen-bytecode"; the registered
// callback forwards to emitBCRW.
// NOTE(review): presumably selected on the command line as `-gen-bytecode`
// (the option category above is titled that way) — confirm against
// mlir::GenRegistration.
static mlir::GenRegistration
genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
[](const RecordKeeper &records, raw_ostream &os) {
return emitBCRW(records, os);
});