191 lines
5.4 KiB
C++
191 lines
5.4 KiB
C++
//===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
|
|
#include "llvm/AsmParser/Parser.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/Passes/PassBuilder.h"
|
|
#include "llvm/ProfileData/InstrProf.h"
|
|
|
|
#include "gmock/gmock.h"
|
|
#include "gtest/gtest.h"
|
|
|
|
#include <tuple>
|
|
|
|
namespace {
|
|
|
|
using namespace llvm;
|
|
|
|
using testing::_;
|
|
using ::testing::DoDefault;
|
|
using ::testing::Invoke;
|
|
using ::testing::NotNull;
|
|
using ::testing::Ref;
|
|
using ::testing::Return;
|
|
using ::testing::Sequence;
|
|
using ::testing::Test;
|
|
using ::testing::TestParamInfo;
|
|
using ::testing::Values;
|
|
using ::testing::WithParamInterface;
|
|
|
|
template <typename Derived> class MockAnalysisHandleBase {
|
|
public:
|
|
class Analysis : public AnalysisInfoMixin<Analysis> {
|
|
public:
|
|
class Result {
|
|
public:
|
|
// Forward invalidation events to the mock handle.
|
|
bool invalidate(Module &M, const PreservedAnalyses &PA,
|
|
ModuleAnalysisManager::Invalidator &Inv) {
|
|
return Handle->invalidate(M, PA, Inv);
|
|
}
|
|
|
|
private:
|
|
explicit Result(Derived *Handle) : Handle(Handle) {}
|
|
|
|
friend MockAnalysisHandleBase;
|
|
Derived *Handle;
|
|
};
|
|
|
|
Result run(Module &M, ModuleAnalysisManager &AM) {
|
|
return Handle->run(M, AM);
|
|
}
|
|
|
|
private:
|
|
friend AnalysisInfoMixin<Analysis>;
|
|
friend MockAnalysisHandleBase;
|
|
static inline AnalysisKey Key;
|
|
|
|
Derived *Handle;
|
|
|
|
explicit Analysis(Derived *Handle) : Handle(Handle) {}
|
|
};
|
|
|
|
Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }
|
|
|
|
typename Analysis::Result getResult() {
|
|
return typename Analysis::Result(static_cast<Derived *>(this));
|
|
}
|
|
|
|
protected:
|
|
void setDefaults() {
|
|
ON_CALL(static_cast<Derived &>(*this), run(_, _))
|
|
.WillByDefault(Return(this->getResult()));
|
|
ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
|
|
.WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
|
|
ModuleAnalysisManager::Invalidator &) {
|
|
auto PAC = PA.template getChecker<Analysis>();
|
|
return !PAC.preserved() &&
|
|
!PAC.template preservedSet<AllAnalysesOn<Module>>();
|
|
}));
|
|
}
|
|
|
|
private:
|
|
friend Derived;
|
|
MockAnalysisHandleBase() = default;
|
|
};
|
|
|
|
class MockModuleAnalysisHandle
|
|
: public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
|
|
public:
|
|
MockModuleAnalysisHandle() { setDefaults(); }
|
|
|
|
MOCK_METHOD(typename Analysis::Result, run,
|
|
(Module &, ModuleAnalysisManager &));
|
|
|
|
MOCK_METHOD(bool, invalidate,
|
|
(Module &, const PreservedAnalyses &,
|
|
ModuleAnalysisManager::Invalidator &));
|
|
};
|
|
|
|
struct PGOInstrumentationGenTest
|
|
: public Test,
|
|
WithParamInterface<std::tuple<StringRef, StringRef>> {
|
|
ModulePassManager MPM;
|
|
PassBuilder PB;
|
|
MockModuleAnalysisHandle MMAHandle;
|
|
LoopAnalysisManager LAM;
|
|
FunctionAnalysisManager FAM;
|
|
CGSCCAnalysisManager CGAM;
|
|
ModuleAnalysisManager MAM;
|
|
LLVMContext Context;
|
|
std::unique_ptr<Module> M;
|
|
|
|
PGOInstrumentationGenTest() {
|
|
MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
|
|
PB.registerModuleAnalyses(MAM);
|
|
PB.registerCGSCCAnalyses(CGAM);
|
|
PB.registerFunctionAnalyses(FAM);
|
|
PB.registerLoopAnalyses(LAM);
|
|
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
|
|
MPM.addPass(
|
|
RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
|
|
MPM.addPass(PGOInstrumentationGen());
|
|
}
|
|
|
|
void parseAssembly(const StringRef IR) {
|
|
SMDiagnostic Error;
|
|
M = parseAssemblyString(IR, Error, Context);
|
|
std::string ErrMsg;
|
|
raw_string_ostream OS(ErrMsg);
|
|
Error.print("", OS);
|
|
|
|
// A failure here means that the test itself is buggy.
|
|
if (!M)
|
|
report_fatal_error(ErrMsg.c_str());
|
|
}
|
|
};
|
|
|
|
/// Test input: a module that contains a single function definition.
static constexpr StringRef CodeWithFuncDefs = R"IR(
define i32 @f(i32 %n) {
entry:
ret i32 0
})IR";
|
|
|
|
/// Test input: a module that contains only a function declaration.
static constexpr StringRef CodeWithFuncDecls = R"IR(
declare i32 @f(i32);
)IR";
|
|
|
|
/// Test input: a module with a global constant table that references a
/// declared (not defined) function.
static constexpr StringRef CodeWithGlobals = R"IR(
@foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
declare i32 @f(i32);
)IR";
|
|
|
|
// Instantiates the fixture once per IR input. The second tuple element is
// returned by the name generator and so becomes the test-name suffix.
// NOTE(review): "Instrumetation" in the prefix below is misspelled; it is left
// untouched here because the prefix is part of the generated test names.
INSTANTIATE_TEST_SUITE_P(
    PGOInstrumetationGenTestSuite, PGOInstrumentationGenTest,
    Values(std::make_tuple(CodeWithFuncDefs, "instrument_function_defs"),
           std::make_tuple(CodeWithFuncDecls, "instrument_function_decls"),
           std::make_tuple(CodeWithGlobals, "instrument_globals")),
    [](const TestParamInfo<PGOInstrumentationGenTest::ParamType> &Info) {
      const StringRef Suffix = std::get<1>(Info.param);
      return Suffix.str();
    });
|
|
|
|
/// Runs the pipeline over the parameter's IR and checks that:
///  - the mock analysis is run on the module and afterwards queried for
///    invalidation, in that order, exactly once each;
///  - the module ends up with a global named by INSTR_PROF_RAW_VERSION_VAR
///    that is a definition rather than a declaration.
TEST_P(PGOInstrumentationGenTest, Instrumented) {
  const StringRef Code = std::get<0>(GetParam());
  parseAssembly(Code);

  ASSERT_THAT(M, NotNull());

  // Both expectations share one sequence, so `run` must be observed before
  // `invalidate`.
  Sequence PassSequence;
  EXPECT_CALL(MMAHandle, run(Ref(*M), _))
      .InSequence(PassSequence)
      .WillOnce(DoDefault());
  EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _))
      .InSequence(PassSequence)
      .WillOnce(DoDefault());

  MPM.run(*M, MAM);

  const auto *IRInstrVar =
      M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
  // This has to be a fatal assertion: the next check dereferences the pointer,
  // so a non-fatal EXPECT would turn a missing variable into a crash of the
  // test binary instead of a reported failure.
  ASSERT_THAT(IRInstrVar, NotNull());
  EXPECT_FALSE(IRInstrVar->isDeclaration());
}
|
|
|
|
} // end anonymous namespace
|