Ivan Butygin e68a20e0b7
[mlir] Reland Move InitAll*** implementation into static library (#151150)
Reland https://github.com/llvm/llvm-project/pull/150805

Shared libs build was broken.

Add `${dialect_libs}` and `${conversion_libs}` to
`MLIRRegisterAllExtensions` because it depends on
`registerConvert***ToLLVMInterface` functions.
2025-07-29 18:15:33 +03:00

394 lines
14 KiB
C++

//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Main entry function for mlir-rewrite.
//
//===----------------------------------------------------------------------===//
#include "mlir/AsmParser/AsmParser.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Tools/ParseUtilities.h"
#include "llvm/ADT/RewriteBuffer.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/LineIterator.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
namespace mlir {
/// Shorthand for the parser-state record describing one parsed operation.
using OperationDefinition = AsmParserState::OperationDefinition;
/// Return the source range spanned by the given operation definition.
///
/// The range ends at the end of the operation's scope. It begins at the
/// earliest of the scope start and the definition locations of the operation's
/// result groups.
SMRange getOpRange(const OperationDefinition &op) {
  const char *rangeBegin = op.scopeLoc.Start.getPointer();
  for (const auto &group : op.resultGroups) {
    const char *groupBegin = group.definition.loc.Start.getPointer();
    // Keep whichever pointer comes first in the buffer.
    if (groupBegin < rangeBegin)
      rangeBegin = groupBegin;
  }
  const char *rangeEnd = op.scopeLoc.End.getPointer();
  return {SMLoc::getFromPointer(rangeBegin), SMLoc::getFromPointer(rangeEnd)};
}
/// Helper to simplify rewriting the source file.
///
/// Bundles the state required to parse the input (context, source manager,
/// dialect registry, parser state and the parsed IR) with an
/// llvm::RewriteBuffer in which textual edits are recorded. Locations passed
/// to the editing methods are pointers into the original input buffer; they
/// are converted to buffer offsets relative to `start`.
class RewritePad {
public:
  /// Create a RewritePad for `inputFilename`. Returns nullptr if the input
  /// cannot be opened or parsed.
  static std::unique_ptr<RewritePad> init(StringRef inputFilename,
                                          StringRef outputFilename);

  /// Return the context the file was parsed into.
  MLIRContext *getContext() { return &context; }

  /// Return the OperationDefinitions of the operations parsed.
  iterator_range<AsmParserState::OperationDefIterator> getOpDefs() {
    return asmState.getOpDefs();
  }

  /// Insert the specified string at the specified location in the original
  /// buffer.
  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
  }

  /// Replace the range of the source text with the corresponding string in the
  /// output.
  void replaceRange(SMRange range, StringRef str) {
    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
                              range.End.getPointer() - range.Start.getPointer(),
                              str);
  }

  /// Replace the range of the operation in the source text with the
  /// corresponding string in the output.
  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
    replaceRange(getOpRange(opDef), newDef);
  }

  /// Return the source string corresponding to the source range. The result
  /// is a view of the text between the range's start and end pointers.
  StringRef getSourceString(SMRange range) {
    return StringRef(range.Start.getPointer(),
                     range.End.getPointer() - range.Start.getPointer());
  }

  /// Return the source string corresponding to operation definition.
  StringRef getSourceString(const OperationDefinition &opDef) {
    auto range = getOpRange(opDef);
    return getSourceString(range);
  }

  /// Write to stream the result of applying all changes to the
  /// original buffer.
  /// Note that it isn't safe to use this function to overwrite memory mapped
  /// files in-place (PR17960).
  ///
  /// The original buffer is not actually changed.
  raw_ostream &write(raw_ostream &stream) const {
    return rewriteBuffer.write(stream);
  }

  /// Return lines that are purely comments. Each returned range starts at the
  /// first non-whitespace character of a line beginning with "//" and runs to
  /// the end of that line as produced by the line iterator.
  SmallVector<SMRange> getSingleLineComments() {
    unsigned curBuf = sourceMgr.getMainFileID();
    const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
    llvm::line_iterator lineIterator(*curMB);
    SmallVector<SMRange> ret;
    for (; !lineIterator.is_at_end(); ++lineIterator) {
      StringRef trimmed = lineIterator->ltrim();
      if (trimmed.starts_with("//")) {
        ret.emplace_back(
            SMLoc::getFromPointer(trimmed.data()),
            SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
      }
    }
    return ret;
  }

  /// Return the IR from parsed file.
  Block *getParsed() { return &parsedIR; }

  /// Return the definition for the given operation. The operation must have a
  /// definition recorded in the parser state; this is asserted, since a
  /// reference cannot represent a missing definition.
  const OperationDefinition &getOpDef(Operation *op) const {
    const OperationDefinition *opDef = asmState.getOpDef(op);
    assert(opDef && "operation has no definition in the parser state");
    return *opDef;
  }

private:
  // The context and state required to parse.
  MLIRContext context;
  llvm::SourceMgr sourceMgr;
  DialectRegistry registry;
  FallbackAsmResourceMap fallbackResourceMap;

  // Storage of textual parsing results.
  AsmParserState asmState;

  // Parsed IR.
  Block parsedIR;

  // The RewriteBuffer is doing most of the real work.
  llvm::RewriteBuffer rewriteBuffer;

  // Start of the original input, used to compute offset. Null until the input
  // buffer has been loaded.
  const char *start = nullptr;
};
/// Create a RewritePad by loading and parsing `inputFilename`.
///
/// Returns nullptr if the input file cannot be opened (the error message is
/// printed to stderr) or if parsing fails.
std::unique_ptr<RewritePad> RewritePad::init(StringRef inputFilename,
                                             StringRef outputFilename) {
  auto pad = std::make_unique<RewritePad>();

  // Register all the dialects needed.
  registerAllDialects(pad->registry);

  // Load the input file into the source manager.
  std::string openError;
  std::unique_ptr<llvm::MemoryBuffer> inputBuffer =
      openInputFile(inputFilename, &openError);
  if (!inputBuffer) {
    llvm::errs() << openError << "\n";
    return nullptr;
  }
  pad->sourceMgr.AddNewSourceBuffer(std::move(inputBuffer), SMLoc());

  // Make the registered dialects available to the context.
  pad->context.appendDialectRegistry(pad->registry);

  // Remember where the main buffer begins so that source locations can later
  // be converted into offsets, and seed the rewrite buffer with its contents.
  const llvm::MemoryBuffer *mainBuffer =
      pad->sourceMgr.getMemoryBuffer(pad->sourceMgr.getMainFileID());
  pad->start = mainBuffer->getBufferStart();
  pad->rewriteBuffer.Initialize(mainBuffer->getBuffer());

  // Parse the file, populating the AsmParserState along the way.
  ParserConfig parseConfig(&pad->context, /*verifyAfterParse=*/true,
                           &pad->fallbackResourceMap);
  // Always allow unregistered.
  pad->context.allowUnregisteredDialects(true);
  if (failed(parseAsmSourceFile(pad->sourceMgr, &pad->parsedIR, parseConfig,
                                &pad->asmState))) {
    return nullptr;
  }

  return pad;
}
/// Return the source range associated with the operation name, i.e. the
/// `loc` recorded on the operation definition.
SMRange getOpNameRange(const OperationDefinition &op) {
  SMRange nameRange = op.loc;
  return nameRange;
}
/// Return true if the operation was written in generic syntax in the original
/// buffer, detected by its location starting with a double quote.
bool isGeneric(const OperationDefinition &op) {
  const char *firstChar = op.loc.Start.getPointer();
  return *firstChar == '"';
}
/// Translate a LogicalResult into a process exit code: EXIT_SUCCESS when the
/// result succeeded, EXIT_FAILURE otherwise.
inline int asMainReturnCode(LogicalResult r) {
  if (r.succeeded())
    return EXIT_SUCCESS;
  return EXIT_FAILURE;
}
/// Rewriter function to invoke. It receives the RewritePad holding the parsed
/// input and an output stream, and returns a LogicalResult.
using RewriterFunction = std::function<mlir::LogicalResult(
mlir::RewritePad &rewriteState, llvm::raw_ostream &os)>;
/// Groups everything known about a rewriter: the command line argument that
/// selects it in mlir-rewrite, a human readable description, and the function
/// that implements it.
class RewriterInfo {
public:
  /// Not meant to be called directly; register rewriters through
  /// RewriterRegistration or registerRewriter instead.
  RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
      : optionName(arg), summary(description), callback(std::move(rewriter)) {}

  /// Returns the command line option that may be passed to 'mlir-rewrite' to
  /// invoke this rewriter.
  StringRef getRewriterArgument() const { return optionName; }

  /// Returns a description for the rewriter.
  StringRef getRewriterDescription() const { return summary; }

  /// Runs the rewriter on `rewriteState` with output stream `os` and returns
  /// the rewriter's result. The rewriter function must be non-null.
  LogicalResult invoke(mlir::RewritePad &rewriteState, raw_ostream &os) const {
    assert(callback && "Cannot call rewriter with null rewriter");
    return callback(rewriteState, os);
  }

private:
  // Command line argument that selects this rewriter (a non-owning view).
  StringRef optionName;

  // Description of the rewriter (a non-owning view).
  StringRef summary;

  // Function implementing the rewriter.
  RewriterFunction callback;
};
/// Registry of all known rewriters. Entries are appended by the
/// RewriterRegistration constructor and enumerated by the RewriterNameParser
/// constructor to build the command line options.
static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
/// Adds command line option for each registered rewriter.
/// The parsed option value is a pointer to the selected RewriterInfo.
struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
/// Registers one literal option per entry of the rewriter registry.
RewriterNameParser(llvm::cl::Option &opt);
/// Prints the option info after sorting the known values by name.
void printOptionInfo(const llvm::cl::Option &o,
size_t globalWidth) const override;
};
/// RewriterRegistration provides a global initializer that registers a rewriter
/// function.
struct RewriterRegistration {
/// Registers `function` under command line argument `arg` with the given
/// `description`.
RewriterRegistration(StringRef arg, StringRef description,
const RewriterFunction &function);
};
/// Append a new RewriterInfo built from the arguments to the global rewriter
/// registry.
RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
                                           const RewriterFunction &function) {
  std::vector<RewriterInfo> &registry = *rewriterRegistry;
  registry.emplace_back(arg, description, function);
}
/// Expose every rewriter currently in the registry as a literal option whose
/// value is a pointer to the corresponding RewriterInfo.
RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
    : llvm::cl::parser<const RewriterInfo *>(opt) {
  for (const RewriterInfo &info : *rewriterRegistry)
    addLiteralOption(info.getRewriterArgument(), &info,
                     info.getRewriterDescription());
}
/// Print the option info with the registered rewriters ordered by name.
void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
                                         size_t globalWidth) const {
  // Sorting reorders the stored values, so constness has to be cast away even
  // though this member function is const.
  auto *self = const_cast<RewriterNameParser *>(this);
  auto compareByName = [](const RewriterNameParser::OptionInfo *lhs,
                          const RewriterNameParser::OptionInfo *rhs) {
    return lhs->Name.compare(rhs->Name);
  };
  llvm::array_pod_sort(self->Values.begin(), self->Values.end(),
                       compareByName);

  // Delegate the actual printing to the base parser implementation.
  llvm::cl::parser<const RewriterInfo *>::printOptionInfo(o, globalWidth);
}
} // namespace mlir
// TODO: Make these injectable too in non-global way.

// Category grouping the options of the simple-rename rewriter.
static llvm::cl::OptionCategory clSimpleRenameCategory("simple-rename options");

// Name of the operations the rename is applied to.
static llvm::cl::opt<std::string>
    simpleRenameOpName("simple-rename-op-name",
                       llvm::cl::desc("Name of op to match on"),
                       llvm::cl::cat(clSimpleRenameCategory));

// Pattern searched for in the source text of a matching operation.
static llvm::cl::opt<std::string>
    simpleRenameMatch("simple-rename-match",
                      llvm::cl::desc("Match string for rename"),
                      llvm::cl::cat(clSimpleRenameCategory));

// Text substituted for the matched pattern.
static llvm::cl::opt<std::string>
    simpleRenameReplace("simple-rename-replace",
                        llvm::cl::desc("Replace string for rename"),
                        llvm::cl::cat(clSimpleRenameCategory));
// Rewriter that does simple renames.
//
// For every parsed operation whose name equals --simple-rename-op-name, the
// source text of the operation is run through a regex substitution
// (--simple-rename-match -> --simple-rename-replace) and the result replaces
// the original text in the rewrite buffer.
//
// Returns failure if the match pattern is not a valid regular expression,
// success otherwise. `os` is unused.
LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) {
  StringRef opName = simpleRenameOpName;
  StringRef match = simpleRenameMatch;
  StringRef replace = simpleRenameReplace;

  // Reject malformed patterns up front instead of silently leaving the input
  // unchanged and reporting success.
  llvm::Regex regex(match);
  std::string regexError;
  if (!regex.isValid(regexError)) {
    llvm::errs() << "invalid regex '" << match
                 << "' for simple-rename-match: " << regexError << "\n";
    return failure();
  }

  rewriteState.getParsed()->walk([&](Operation *op) {
    if (op->getName().getStringRef() != opName)
      return;

    const OperationDefinition &opDef = rewriteState.getOpDef(op);
    SMRange range = getOpRange(opDef);
    // This is a little bit overkill for simple.
    std::string str = regex.sub(replace, rewriteState.getSourceString(range));
    rewriteState.replaceRange(range, str);
  });
  return success();
}
// Registers the simpleRename rewriter under the "simple-rename" argument.
static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
"Perform a simple rename",
simpleRename);
// Rewriter that inserts range markers.
//
// Every operation definition is wrapped in "<" ... ">". The operation name is
// wrapped in "[" ... "]" when the operation uses generic syntax and in
// "![" ... "]!" otherwise. Lines that are purely comments are wrapped in
// "{" ... "}". Always succeeds; `os` is unused.
LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) {
  for (const auto &opDef : rewriteState.getOpDefs()) {
    // Delimit the whole operation.
    const SMRange opRange = getOpRange(opDef);
    rewriteState.insertText(opRange.Start, "<");
    rewriteState.insertText(opRange.End, ">");

    // Delimit the operation name, using a different marker for custom syntax.
    const SMRange nameRange = getOpNameRange(opDef);
    const bool generic = isGeneric(opDef);
    rewriteState.insertText(nameRange.Start, generic ? "[" : "![");
    rewriteState.insertText(nameRange.End, generic ? "]" : "]!");
  }

  // Highlight all comment lines.
  // TODO: Could be replaced if this is kept in memory.
  for (const SMRange &comment : rewriteState.getSingleLineComments()) {
    rewriteState.insertText(comment.Start, "{");
    rewriteState.insertText(comment.End, "}");
  }

  return success();
}
// Registers the markRanges rewriter under the "mark-ranges" argument.
static mlir::RewriterRegistration
rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges);
/// Entry point of mlir-rewrite.
///
/// Parses the command line, loads and parses the input file into a
/// RewritePad, runs the selected rewriter and, if it succeeds, writes the
/// rewritten text to the output file. Returns EXIT_SUCCESS on success and
/// EXIT_FAILURE if no rewriter was selected, the input could not be loaded,
/// the output could not be opened, or the rewriter failed.
int main(int argc, char **argv) {
  // Standard LLVM tool initialization. It must be the first object created so
  // that it is destroyed last.
  llvm::InitLLVM y(argc, argv);

  llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-"));

  llvm::cl::opt<std::string> outputFilename(
      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
      llvm::cl::init("-"));

  // The rewriter to run is selected via one of the literal options created by
  // RewriterNameParser from the rewriter registry.
  llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
      rewriter("", llvm::cl::desc("Rewriter to run"));

  std::string helpHeader = "mlir-rewrite";

  llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader);

  // If no rewriter has been selected, exit with error code. Could also just
  // return but it's unlikely this was intentionally being used as `cp`.
  if (!rewriter) {
    llvm::errs() << "No rewriter selected!\n";
    return mlir::asMainReturnCode(mlir::failure());
  }

  // Set up rewrite buffer.
  auto rewriterOr = RewritePad::init(inputFilename, outputFilename);
  if (!rewriterOr)
    return mlir::asMainReturnCode(mlir::failure());

  // Set up the output file.
  std::string errorMessage;
  auto output = openOutputFile(outputFilename, &errorMessage);
  if (!output) {
    llvm::errs() << errorMessage << "\n";
    return mlir::asMainReturnCode(mlir::failure());
  }

  // Only emit and keep the output when the rewriter succeeded.
  LogicalResult result = rewriter->invoke(*rewriterOr, output->os());
  if (succeeded(result)) {
    rewriterOr->write(output->os());
    output->keep();
  }
  return mlir::asMainReturnCode(result);
}