
Reland https://github.com/llvm/llvm-project/pull/150805 Shared libs build was broken. Add `${dialect_libs}` and `${conversion_libs}` to `MLIRRegisterAllExtensions` because it depends on `registerConvert***ToLLVMInterface` functions.
394 lines
14 KiB
C++
394 lines
14 KiB
C++
//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Main entry function for mlir-rewrite.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/AsmParser/AsmParser.h"
|
|
#include "mlir/AsmParser/AsmParserState.h"
|
|
#include "mlir/IR/AsmState.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/InitAllDialects.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Tools/ParseUtilities.h"
|
|
#include "llvm/ADT/RewriteBuffer.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/InitLLVM.h"
|
|
#include "llvm/Support/LineIterator.h"
|
|
#include "llvm/Support/ManagedStatic.h"
|
|
#include "llvm/Support/Regex.h"
|
|
#include "llvm/Support/SourceMgr.h"
|
|
#include "llvm/Support/ToolOutputFile.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
/// Shorthand for the per-operation source information that the parser records
/// in AsmParserState.
using OperationDefinition = AsmParserState::OperationDefinition;
|
|
|
|
/// Return the source code associated with the OperationDefinition.
|
|
SMRange getOpRange(const OperationDefinition &op) {
|
|
const char *startOp = op.scopeLoc.Start.getPointer();
|
|
const char *endOp = op.scopeLoc.End.getPointer();
|
|
|
|
for (const auto &res : op.resultGroups) {
|
|
SMRange range = res.definition.loc;
|
|
startOp = std::min(startOp, range.Start.getPointer());
|
|
}
|
|
return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
|
|
}
|
|
|
|
/// Helper to simplify rewriting the source file.
|
|
class RewritePad {
|
|
public:
|
|
static std::unique_ptr<RewritePad> init(StringRef inputFilename,
|
|
StringRef outputFilename);
|
|
|
|
/// Return the context the file was parsed into.
|
|
MLIRContext *getContext() { return &context; }
|
|
|
|
/// Return the OperationDefinition's of the operation's parsed.
|
|
iterator_range<AsmParserState::OperationDefIterator> getOpDefs() {
|
|
return asmState.getOpDefs();
|
|
}
|
|
|
|
/// Insert the specified string at the specified location in the original
|
|
/// buffer.
|
|
void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
|
|
rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
|
|
}
|
|
|
|
/// Replace the range of the source text with the corresponding string in the
|
|
/// output.
|
|
void replaceRange(SMRange range, StringRef str) {
|
|
rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
|
|
range.End.getPointer() - range.Start.getPointer(),
|
|
str);
|
|
}
|
|
|
|
/// Replace the range of the operation in the source text with the
|
|
/// corresponding string in the output.
|
|
void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
|
|
replaceRange(getOpRange(opDef), newDef);
|
|
}
|
|
|
|
/// Return the source string corresponding to the source range.
|
|
StringRef getSourceString(SMRange range) {
|
|
return StringRef(range.Start.getPointer(),
|
|
range.End.getPointer() - range.Start.getPointer());
|
|
}
|
|
|
|
/// Return the source string corresponding to operation definition.
|
|
StringRef getSourceString(const OperationDefinition &opDef) {
|
|
auto range = getOpRange(opDef);
|
|
return getSourceString(range);
|
|
}
|
|
|
|
/// Write to stream the result of applying all changes to the
|
|
/// original buffer.
|
|
/// Note that it isn't safe to use this function to overwrite memory mapped
|
|
/// files in-place (PR17960).
|
|
///
|
|
/// The original buffer is not actually changed.
|
|
raw_ostream &write(raw_ostream &stream) const {
|
|
return rewriteBuffer.write(stream);
|
|
}
|
|
|
|
/// Return lines that are purely comments.
|
|
SmallVector<SMRange> getSingleLineComments() {
|
|
unsigned curBuf = sourceMgr.getMainFileID();
|
|
const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
|
|
llvm::line_iterator lineIterator(*curMB);
|
|
SmallVector<SMRange> ret;
|
|
for (; !lineIterator.is_at_end(); ++lineIterator) {
|
|
StringRef trimmed = lineIterator->ltrim();
|
|
if (trimmed.starts_with("//")) {
|
|
ret.emplace_back(
|
|
SMLoc::getFromPointer(trimmed.data()),
|
|
SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
|
|
}
|
|
}
|
|
return ret;
|
|
}
|
|
|
|
/// Return the IR from parsed file.
|
|
Block *getParsed() { return &parsedIR; }
|
|
|
|
/// Return the definition for the given operation, or nullptr if the given
|
|
/// operation does not have a definition.
|
|
const OperationDefinition &getOpDef(Operation *op) const {
|
|
return *asmState.getOpDef(op);
|
|
}
|
|
|
|
private:
|
|
// The context and state required to parse.
|
|
MLIRContext context;
|
|
llvm::SourceMgr sourceMgr;
|
|
DialectRegistry registry;
|
|
FallbackAsmResourceMap fallbackResourceMap;
|
|
|
|
// Storage of textual parsing results.
|
|
AsmParserState asmState;
|
|
|
|
// Parsed IR.
|
|
Block parsedIR;
|
|
|
|
// The RewriteBuffer is doing most of the real work.
|
|
llvm::RewriteBuffer rewriteBuffer;
|
|
|
|
// Start of the original input, used to compute offset.
|
|
const char *start;
|
|
};
|
|
|
|
std::unique_ptr<RewritePad> RewritePad::init(StringRef inputFilename,
                                             StringRef outputFilename) {
  auto pad = std::make_unique<RewritePad>();

  // Make every dialect known to registerAllDialects available to the parser.
  registerAllDialects(pad->registry);

  // Load the input file into the source manager; report open errors on
  // stderr.
  std::string errorMessage;
  std::unique_ptr<llvm::MemoryBuffer> inputBuffer =
      openInputFile(inputFilename, &errorMessage);
  if (!inputBuffer) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }
  pad->sourceMgr.AddNewSourceBuffer(std::move(inputBuffer), SMLoc());

  // Hand the registered dialects to the context.
  MLIRContext &ctx = pad->context;
  ctx.appendDialectRegistry(pad->registry);

  // Remember where the main buffer begins so SMLocs can be turned into
  // offsets, and seed the rewrite buffer with the unmodified text.
  const llvm::MemoryBuffer *mainBuffer =
      pad->sourceMgr.getMemoryBuffer(pad->sourceMgr.getMainFileID());
  pad->start = mainBuffer->getBufferStart();
  pad->rewriteBuffer.Initialize(mainBuffer->getBuffer());

  // Parse (verifying afterwards) while recording source locations in the
  // AsmParserState. Unregistered dialects are always accepted.
  ParserConfig config(&ctx, /*verifyAfterParse=*/true,
                      &pad->fallbackResourceMap);
  ctx.allowUnregisteredDialects(true);
  LogicalResult parseResult = parseAsmSourceFile(
      pad->sourceMgr, &pad->parsedIR, config, &pad->asmState);
  if (failed(parseResult))
    return nullptr;

  return pad;
}
|
|
|
|
/// Return the source code associated with the operation name.
|
|
SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
|
|
|
|
/// Return whether the operation was printed using generic syntax in original
|
|
/// buffer.
|
|
bool isGeneric(const OperationDefinition &op) {
|
|
return op.loc.Start.getPointer()[0] == '"';
|
|
}
|
|
|
|
/// Map a LogicalResult onto a process exit code.
inline int asMainReturnCode(LogicalResult r) {
  if (failed(r))
    return EXIT_FAILURE;
  return EXIT_SUCCESS;
}
|
|
|
|
/// Reriter function to invoke.
|
|
using RewriterFunction = std::function<mlir::LogicalResult(
|
|
mlir::RewritePad &rewriteState, llvm::raw_ostream &os)>;
|
|
|
|
/// Structure to group information about a rewriter (argument to invoke via
|
|
/// mlir-tblgen, description, and rewriter function).
|
|
class RewriterInfo {
|
|
public:
|
|
/// RewriterInfo constructor should not be invoked directly, instead use
|
|
/// RewriterRegistration or registerRewriter.
|
|
RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
|
|
: arg(arg), description(description), rewriter(std::move(rewriter)) {}
|
|
|
|
/// Invokes the rewriter and returns whether the rewriter failed.
|
|
LogicalResult invoke(mlir::RewritePad &rewriteState, raw_ostream &os) const {
|
|
assert(rewriter && "Cannot call rewriter with null rewriter");
|
|
return rewriter(rewriteState, os);
|
|
}
|
|
|
|
/// Returns the command line option that may be passed to 'mlir-rewrite' to
|
|
/// invoke this rewriter.
|
|
StringRef getRewriterArgument() const { return arg; }
|
|
|
|
/// Returns a description for the rewriter.
|
|
StringRef getRewriterDescription() const { return description; }
|
|
|
|
private:
|
|
// The argument with which to invoke the rewriter via mlir-tblgen.
|
|
StringRef arg;
|
|
|
|
// Description of the rewriter.
|
|
StringRef description;
|
|
|
|
// Rewritererator function.
|
|
RewriterFunction rewriter;
|
|
};
|
|
|
|
/// Global list of all rewriters. RewriterRegistration constructors append to
/// it, and RewriterNameParser reads it to build the rewriter command-line
/// option.
static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
|
|
|
|
/// Adds command line option for each registered rewriter.
|
|
struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
|
|
RewriterNameParser(llvm::cl::Option &opt);
|
|
|
|
void printOptionInfo(const llvm::cl::Option &o,
|
|
size_t globalWidth) const override;
|
|
};
|
|
|
|
/// RewriterRegistration provides a global initializer that registers a rewriter
|
|
/// function.
|
|
struct RewriterRegistration {
|
|
RewriterRegistration(StringRef arg, StringRef description,
|
|
const RewriterFunction &function);
|
|
};
|
|
|
|
RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
                                           const RewriterFunction &function) {
  // Record the rewriter so RewriterNameParser can offer it on the command line.
  std::vector<RewriterInfo> &registry = *rewriterRegistry;
  registry.emplace_back(arg, description, function);
}
|
|
|
|
RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
    : llvm::cl::parser<const RewriterInfo *>(opt) {
  // Each registered rewriter becomes one selectable literal of the option.
  for (const RewriterInfo &info : *rewriterRegistry)
    addLiteralOption(info.getRewriterArgument(), &info,
                     info.getRewriterDescription());
}
|
|
|
|
void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
                                         size_t globalWidth) const {
  using Base = llvm::cl::parser<const RewriterInfo *>;
  using Info = RewriterNameParser::OptionInfo;

  // List the rewriters alphabetically. Sorting mutates Values, which is const
  // inside this const member, hence the const_cast.
  auto *self = const_cast<RewriterNameParser *>(this);
  llvm::array_pod_sort(self->Values.begin(), self->Values.end(),
                       [](const Info *lhs, const Info *rhs) {
                         return lhs->Name.compare(rhs->Name);
                       });

  Base::printOptionInfo(o, globalWidth);
}
|
|
|
|
} // namespace mlir
|
|
|
|
// TODO: Make these injectable too in non-global way.
|
|
static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"};
|
|
static llvm::cl::opt<std::string> simpleRenameOpName{
|
|
"simple-rename-op-name", llvm::cl::desc("Name of op to match on"),
|
|
llvm::cl::cat(clSimpleRenameCategory)};
|
|
static llvm::cl::opt<std::string> simpleRenameMatch{
|
|
"simple-rename-match", llvm::cl::desc("Match string for rename"),
|
|
llvm::cl::cat(clSimpleRenameCategory)};
|
|
static llvm::cl::opt<std::string> simpleRenameReplace{
|
|
"simple-rename-replace", llvm::cl::desc("Replace string for rename"),
|
|
llvm::cl::cat(clSimpleRenameCategory)};
|
|
|
|
// Rewriter that does simple renames.
|
|
LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) {
|
|
StringRef opName = simpleRenameOpName;
|
|
StringRef match = simpleRenameMatch;
|
|
StringRef replace = simpleRenameReplace;
|
|
llvm::Regex regex(match);
|
|
|
|
rewriteState.getParsed()->walk([&](Operation *op) {
|
|
if (op->getName().getStringRef() != opName)
|
|
return;
|
|
|
|
const OperationDefinition &opDef = rewriteState.getOpDef(op);
|
|
SMRange range = getOpRange(opDef);
|
|
// This is a little bit overkill for simple.
|
|
std::string str = regex.sub(replace, rewriteState.getSourceString(range));
|
|
rewriteState.replaceRange(range, str);
|
|
});
|
|
return success();
|
|
}
|
|
|
|
// Make the simple-rename rewriter selectable via `--simple-rename`.
static mlir::RewriterRegistration rewriteSimpleRename{
    "simple-rename", "Perform a simple rename", simpleRename};
|
|
|
|
// Rewriter that insert range markers.
|
|
LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) {
|
|
for (const auto &it : rewriteState.getOpDefs()) {
|
|
auto [startOp, endOp] = getOpRange(it);
|
|
|
|
rewriteState.insertText(startOp, "<");
|
|
rewriteState.insertText(endOp, ">");
|
|
|
|
auto nameRange = getOpNameRange(it);
|
|
|
|
if (isGeneric(it)) {
|
|
rewriteState.insertText(nameRange.Start, "[");
|
|
rewriteState.insertText(nameRange.End, "]");
|
|
} else {
|
|
rewriteState.insertText(nameRange.Start, "![");
|
|
rewriteState.insertText(nameRange.End, "]!");
|
|
}
|
|
}
|
|
|
|
// Highlight all comment lines.
|
|
// TODO: Could be replaced if this is kept in memory.
|
|
for (auto commentLine : rewriteState.getSingleLineComments()) {
|
|
rewriteState.insertText(commentLine.Start, "{");
|
|
rewriteState.insertText(commentLine.End, "}");
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
// Make the range-marking rewriter selectable via `--mark-ranges`.
static mlir::RewriterRegistration rewriteMarkRanges{
    "mark-ranges", "Indicate ranges parsed", markRanges};
|
|
|
|
int main(int argc, char **argv) {
|
|
llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
|
|
llvm::cl::desc("<input file>"),
|
|
llvm::cl::init("-"));
|
|
|
|
llvm::cl::opt<std::string> outputFilename(
|
|
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
|
|
llvm::cl::init("-"));
|
|
|
|
llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
|
|
rewriter("", llvm::cl::desc("Rewriter to run"));
|
|
|
|
std::string helpHeader = "mlir-rewrite";
|
|
|
|
llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader);
|
|
|
|
// If no rewriter has been selected, exit with error code. Could also just
|
|
// return but its unlikely this was intentionally being used as `cp`.
|
|
if (!rewriter) {
|
|
llvm::errs() << "No rewriter selected!\n";
|
|
return mlir::asMainReturnCode(mlir::failure());
|
|
}
|
|
|
|
// Set up rewrite buffer.
|
|
auto rewriterOr = RewritePad::init(inputFilename, outputFilename);
|
|
if (!rewriterOr)
|
|
return mlir::asMainReturnCode(mlir::failure());
|
|
|
|
// Set up the output file.
|
|
std::string errorMessage;
|
|
auto output = openOutputFile(outputFilename, &errorMessage);
|
|
if (!output) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return mlir::asMainReturnCode(mlir::failure());
|
|
}
|
|
|
|
LogicalResult result = rewriter->invoke(*rewriterOr, output->os());
|
|
if (succeeded(result)) {
|
|
rewriterOr->write(output->os());
|
|
output->keep();
|
|
}
|
|
return mlir::asMainReturnCode(result);
|
|
}
|