
Don't create a fix when the object that `create` is invoked on is a temporary, as the `create` method requires a reference.
140 lines
5.5 KiB
C++
140 lines
5.5 KiB
C++
//===--- UseNewMLIROpBuilderCheck.cpp - clang-tidy ------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "UseNewMLIROpBuilderCheck.h"
|
|
#include "clang/ASTMatchers/ASTMatchers.h"
|
|
#include "clang/Basic/LLVM.h"
|
|
#include "clang/Lex/Lexer.h"
|
|
#include "clang/Tooling/Transformer/RangeSelector.h"
|
|
#include "clang/Tooling/Transformer/RewriteRule.h"
|
|
#include "clang/Tooling/Transformer/SourceCode.h"
|
|
#include "clang/Tooling/Transformer/Stencil.h"
|
|
#include "llvm/Support/Error.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
namespace clang::tidy::llvm_check {
|
|
namespace {
|
|
|
|
using namespace ::clang::ast_matchers;
|
|
using namespace ::clang::transformer;
|
|
|
|
/// Builds the edit generator that turns `builder.create<OpType>(args...)` into
/// `OpType::create(builder, args...)`.
///
/// \param Call     selects the whole member-call expression that is replaced.
/// \param Builder  selects the object expression `create` is invoked on.
/// \param CallArgs selects the text between the call's parentheses.
///
/// If the call begins inside a macro expansion, an empty edit is produced so
/// the diagnostic is still reported but no fix is offered. An error is
/// returned (and no fix is produced) when the explicit template argument list
/// cannot be located in the source text.
EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
                      RangeSelector CallArgs) {
  // This is using an EditGenerator rather than ASTEdit as we want to warn even
  // if in macro.
  return [Call = std::move(Call), Builder = std::move(Builder),
          CallArgs = std::move(CallArgs)](const MatchFinder::MatchResult &Result)
             -> Expected<SmallVector<transformer::Edit, 1>> {
    Expected<CharSourceRange> CallRange = Call(Result);
    if (!CallRange)
      return CallRange.takeError();
    SourceManager &SM = *Result.SourceManager;
    const LangOptions &LangOpts = Result.Context->getLangOpts();
    SourceLocation Begin = CallRange->getBegin();

    // This will result in just a warning and no edit.
    if (Begin.isMacroID()) {
      while (SM.isMacroArgExpansion(Begin))
        Begin = SM.getImmediateExpansionRange(Begin).getBegin();
      Edit WarnOnly;
      WarnOnly.Kind = EditKind::Range;
      WarnOnly.Range = CharSourceRange::getCharRange(Begin, Begin);
      return SmallVector<Edit, 1>({WarnOnly});
    }

    Expected<CharSourceRange> BuilderRange = Builder(Result);
    if (!BuilderRange)
      return BuilderRange.takeError();
    Expected<CharSourceRange> CallArgsRange = CallArgs(Result);
    if (!CallArgsRange)
      return CallArgsRange.takeError();

    auto GetText = [&](const CharSourceRange &Range) {
      return clang::Lexer::getSourceText(Range, SM, LangOpts);
    };
    auto MakeError = [](const char *Message) {
      return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
                                                 Message);
    };

    // Steps to the following raw token; yields nothing once the end of the
    // call has been reached.
    auto NextToken = [&](std::optional<Token> CurrentToken) {
      if (!CurrentToken)
        return CurrentToken;
      if (CurrentToken->getEndLoc() >= CallRange->getEnd())
        return std::optional<Token>();
      return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
                                         LangOpts);
    };

    // Find the '<' opening the explicit template argument list. Lexing starts
    // after the builder expression so that a '<' inside the builder itself
    // (e.g. `std::get<0>(T).create<OpTy>(...)`) is not picked up.
    std::optional<Token> LessToken =
        clang::Lexer::findNextToken(BuilderRange->getEnd(), SM, LangOpts);
    while (LessToken && !LessToken->is(clang::tok::less))
      LessToken = NextToken(LessToken);
    if (!LessToken)
      return MakeError("missing '<' token");

    // Find the matching '>', tracking nesting so `create<A<B>>` and
    // `create<A<B> >` both work. The raw lexer does not split `>>`, so that
    // token closes two levels at once. ArgsEnd is the (exclusive) location of
    // the '>' character that closes the outer list.
    unsigned Depth = 1;
    SourceLocation ArgsEnd;
    for (std::optional<Token> Tok = NextToken(LessToken); Tok;
         Tok = NextToken(Tok)) {
      if (Tok->is(clang::tok::less)) {
        ++Depth;
      } else if (Tok->is(clang::tok::greater)) {
        if (--Depth == 0) {
          ArgsEnd = Tok->getLocation();
          break;
        }
      } else if (Tok->is(clang::tok::greatergreater)) {
        if (Depth <= 2) {
          // Depth 1: the first '>' closes the list; depth 2: the second one.
          ArgsEnd = Tok->getLocation().getLocWithOffset(
              static_cast<int>(Depth) - 1);
          break;
        }
        Depth -= 2;
      }
    }
    if (ArgsEnd.isInvalid())
      return MakeError("missing '>' token");

    // Use the template argument as written so that the rewritten code looks
    // closest to the original.
    StringRef OpType =
        GetText(CharSourceRange::getCharRange(LessToken->getEndLoc(), ArgsEnd))
            .trim();
    if (OpType.empty())
      return MakeError("missing template argument");

    std::string CallArgsStr;
    // Only emit args if there are any. Trailing whitespace is kept so a ')'
    // on its own line stays there.
    if (StringRef CallArgsText = GetText(*CallArgsRange).ltrim();
        !CallArgsText.rtrim().empty())
      CallArgsStr = llvm::formatv(", {}", CallArgsText).str();

    Edit Replace;
    Replace.Kind = EditKind::Range;
    Replace.Range = *CallRange;
    Replace.Replacement = llvm::formatv("{}::create({}{})", OpType,
                                        GetText(*BuilderRange), CallArgsStr)
                              .str();
    return SmallVector<Edit, 1>({Replace});
  };
}
|
|
|
|
RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
|
|
Stencil message = cat("use 'OpType::create(builder, ...)' instead of "
|
|
"'builder.create<OpType>(...)'");
|
|
// Match a create call on an OpBuilder.
|
|
ast_matchers::internal::Matcher<Stmt> base =
|
|
cxxMemberCallExpr(
|
|
on(expr(hasType(
|
|
cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"))))
|
|
.bind("builder")),
|
|
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
|
|
callee(cxxMethodDecl(hasName("create"))))
|
|
.bind("call");
|
|
return applyFirst(
|
|
// Attempt rewrite given an lvalue builder, else just warn.
|
|
{makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), base),
|
|
rewrite(node("call"), node("builder"), callArgs("call")),
|
|
message),
|
|
makeRule(base, noopEdit(node("call")), message)});
|
|
}
|
|
} // namespace
|
|
|
|
// The check's matching and fix-it logic lives entirely in the Transformer
// rule returned by useNewMlirOpBuilderCheckRule(); the constructor only hands
// that rule, together with the check name and context, to the base class.
UseNewMlirOpBuilderCheck::UseNewMlirOpBuilderCheck(StringRef Name,
                                                   ClangTidyContext *Context)
    : TransformerClangTidyCheck(useNewMlirOpBuilderCheckRule(), Name,
                                Context) {}
|
|
|
|
} // namespace clang::tidy::llvm_check
|