This commits adds a C++ generator to PDLL that generates wrapper PDL patterns directly usable in C++ code, and also generates the definitions of native constraints/rewrites that have code bodies specified in PDLL. This generator is effectively the PDLL equivalent of the current DRR generator, and will allow easy replacement of DRR patterns with PDLL patterns. A followup will start to utilize this for end-to-end integration testing and show case how to use this as a drop-in replacement for DRR tablegen usage. Differential Revision: https://reviews.llvm.org/D119781
220 lines
8.8 KiB
C++
220 lines
8.8 KiB
C++
//===- CPPGen.cpp ---------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This files contains a PDLL generator that outputs C++ code that defines PDLL
|
|
// patterns as individual C++ PDLPatternModules for direct use in native code,
|
|
// and also defines any native constraints whose bodies were defined in PDLL.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Tools/PDLL/CodeGen/CPPGen.h"
|
|
#include "mlir/Dialect/PDL/IR/PDL.h"
|
|
#include "mlir/Dialect/PDL/IR/PDLOps.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/Tools/PDLL/AST/Nodes.h"
|
|
#include "llvm/ADT/SmallString.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/ErrorHandling.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::pdll;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CodeGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
class CodeGen {
|
|
public:
|
|
CodeGen(raw_ostream &os) : os(os) {}
|
|
|
|
/// Generate C++ code for the given PDL pattern module.
|
|
void generate(const ast::Module &astModule, ModuleOp module);
|
|
|
|
private:
|
|
void generate(pdl::PatternOp pattern, StringRef patternName,
|
|
StringSet<> &nativeFunctions);
|
|
|
|
/// Generate C++ code for all user defined constraints and rewrites with
|
|
/// native code.
|
|
void generateConstraintAndRewrites(const ast::Module &astModule,
|
|
ModuleOp module,
|
|
StringSet<> &nativeFunctions);
|
|
void generate(const ast::UserConstraintDecl *decl,
|
|
StringSet<> &nativeFunctions);
|
|
void generate(const ast::UserRewriteDecl *decl, StringSet<> &nativeFunctions);
|
|
void generateConstraintOrRewrite(StringRef name, bool isConstraint,
|
|
ArrayRef<ast::VariableDecl *> inputs,
|
|
StringRef codeBlock,
|
|
StringSet<> &nativeFunctions);
|
|
|
|
/// The stream to output to.
|
|
raw_ostream &os;
|
|
};
|
|
} // namespace
|
|
|
|
void CodeGen::generate(const ast::Module &astModule, ModuleOp module) {
|
|
SetVector<std::string, SmallVector<std::string>, StringSet<>> patternNames;
|
|
StringSet<> nativeFunctions;
|
|
|
|
// Generate code for any native functions within the module.
|
|
generateConstraintAndRewrites(astModule, module, nativeFunctions);
|
|
|
|
os << "namespace {\n";
|
|
std::string basePatternName = "GeneratedPDLLPattern";
|
|
int patternIndex = 0;
|
|
for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
|
|
// If the pattern has a name, use that. Otherwise, generate a unique name.
|
|
if (Optional<StringRef> patternName = pattern.sym_name()) {
|
|
patternNames.insert(patternName->str());
|
|
} else {
|
|
std::string name;
|
|
do {
|
|
name = (basePatternName + Twine(patternIndex++)).str();
|
|
} while (!patternNames.insert(name));
|
|
}
|
|
|
|
generate(pattern, patternNames.back(), nativeFunctions);
|
|
}
|
|
os << "} // end namespace\n\n";
|
|
|
|
// Emit function to add the generated matchers to the pattern list.
|
|
os << "static void LLVM_ATTRIBUTE_UNUSED populateGeneratedPDLLPatterns("
|
|
"::mlir::RewritePatternSet &patterns) {\n";
|
|
for (const auto &name : patternNames)
|
|
os << " patterns.add<" << name << ">(patterns.getContext());\n";
|
|
os << "}\n";
|
|
}
|
|
|
|
void CodeGen::generate(pdl::PatternOp pattern, StringRef patternName,
|
|
StringSet<> &nativeFunctions) {
|
|
const char *patternClassStartStr = R"(
|
|
struct {0} : ::mlir::PDLPatternModule {{
|
|
{0}(::mlir::MLIRContext *context)
|
|
: ::mlir::PDLPatternModule(::mlir::parseSourceString<::mlir::ModuleOp>(
|
|
)";
|
|
os << llvm::formatv(patternClassStartStr, patternName);
|
|
|
|
os << "R\"mlir(";
|
|
pattern->print(os, OpPrintingFlags().enableDebugInfo());
|
|
os << "\n )mlir\", context)) {\n";
|
|
|
|
// Register any native functions used within the pattern.
|
|
StringSet<> registeredNativeFunctions;
|
|
auto checkRegisterNativeFn = [&](StringRef fnName, StringRef fnType) {
|
|
if (!nativeFunctions.count(fnName) ||
|
|
!registeredNativeFunctions.insert(fnName).second)
|
|
return;
|
|
os << " register" << fnType << "Function(\"" << fnName << "\", "
|
|
<< fnName << "PDLFn);\n";
|
|
};
|
|
pattern.walk([&](Operation *op) {
|
|
if (auto constraintOp = dyn_cast<pdl::ApplyNativeConstraintOp>(op))
|
|
checkRegisterNativeFn(constraintOp.name(), "Constraint");
|
|
else if (auto rewriteOp = dyn_cast<pdl::ApplyNativeRewriteOp>(op))
|
|
checkRegisterNativeFn(rewriteOp.name(), "Rewrite");
|
|
});
|
|
os << " }\n};\n\n";
|
|
}
|
|
|
|
void CodeGen::generateConstraintAndRewrites(const ast::Module &astModule,
|
|
ModuleOp module,
|
|
StringSet<> &nativeFunctions) {
|
|
// First check to see which constraints and rewrites are actually referenced
|
|
// in the module.
|
|
StringSet<> usedFns;
|
|
module.walk([&](Operation *op) {
|
|
TypeSwitch<Operation *>(op)
|
|
.Case<pdl::ApplyNativeConstraintOp, pdl::ApplyNativeRewriteOp>(
|
|
[&](auto op) { usedFns.insert(op.name()); });
|
|
});
|
|
|
|
for (const ast::Decl *decl : astModule.getChildren()) {
|
|
TypeSwitch<const ast::Decl *>(decl)
|
|
.Case<ast::UserConstraintDecl, ast::UserRewriteDecl>(
|
|
[&](const auto *decl) {
|
|
// We only generate code for inline native decls that have been
|
|
// referenced.
|
|
if (decl->getCodeBlock() &&
|
|
usedFns.contains(decl->getName().getName()))
|
|
this->generate(decl, nativeFunctions);
|
|
});
|
|
}
|
|
}
|
|
|
|
void CodeGen::generate(const ast::UserConstraintDecl *decl,
|
|
StringSet<> &nativeFunctions) {
|
|
return generateConstraintOrRewrite(decl->getName().getName(),
|
|
/*isConstraint=*/true, decl->getInputs(),
|
|
*decl->getCodeBlock(), nativeFunctions);
|
|
}
|
|
|
|
void CodeGen::generate(const ast::UserRewriteDecl *decl,
|
|
StringSet<> &nativeFunctions) {
|
|
return generateConstraintOrRewrite(decl->getName().getName(),
|
|
/*isConstraint=*/false, decl->getInputs(),
|
|
*decl->getCodeBlock(), nativeFunctions);
|
|
}
|
|
|
|
void CodeGen::generateConstraintOrRewrite(StringRef name, bool isConstraint,
|
|
ArrayRef<ast::VariableDecl *> inputs,
|
|
StringRef codeBlock,
|
|
StringSet<> &nativeFunctions) {
|
|
nativeFunctions.insert(name);
|
|
|
|
// TODO: Should there be something explicit for handling optionality?
|
|
auto getCppType = [&](ast::Type type) -> StringRef {
|
|
return llvm::TypeSwitch<ast::Type, StringRef>(type)
|
|
.Case([&](ast::AttributeType) { return "::mlir::Attribute"; })
|
|
.Case([&](ast::OperationType) {
|
|
// TODO: Allow using the derived Op class when possible.
|
|
return "::mlir::Operation *";
|
|
})
|
|
.Case([&](ast::TypeType) { return "::mlir::Type"; })
|
|
.Case([&](ast::ValueType) { return "::mlir::Value"; })
|
|
.Case([&](ast::TypeRangeType) { return "::mlir::TypeRange"; })
|
|
.Case([&](ast::ValueRangeType) { return "::mlir::ValueRange"; });
|
|
};
|
|
|
|
// FIXME: We currently do not have a modeling for the "constant params"
|
|
// support PDL provides. We should either figure out a modeling for this, or
|
|
// refactor the support within PDL to be something a bit more reasonable for
|
|
// what we need as a frontend.
|
|
os << "static " << (isConstraint ? "::mlir::LogicalResult " : "void ") << name
|
|
<< "PDLFn(::llvm::ArrayRef<::mlir::PDLValue> values, "
|
|
"::mlir::ArrayAttr constParams, ::mlir::PatternRewriter &rewriter"
|
|
<< (isConstraint ? "" : ", ::mlir::PDLResultList &results") << ") {\n";
|
|
|
|
const char *argumentInitStr = R"(
|
|
{0} {1} = {{};
|
|
if (values[{2}])
|
|
{1} = values[{2}].cast<{0}>();
|
|
(void){1};
|
|
)";
|
|
for (const auto &it : llvm::enumerate(inputs)) {
|
|
const ast::VariableDecl *input = it.value();
|
|
os << llvm::formatv(argumentInitStr, getCppType(input->getType()),
|
|
input->getName().getName(), it.index());
|
|
}
|
|
|
|
os << " " << codeBlock.trim() << "\n}\n";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CPPGen
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void mlir::pdll::codegenPDLLToCPP(const ast::Module &astModule, ModuleOp module,
|
|
raw_ostream &os) {
|
|
CodeGen codegen(os);
|
|
codegen.generate(astModule, module);
|
|
}
|