//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The analysis determines the convergence region for each basic block of
// the module, and provides a tree-like structure describing the region
// hierarchy.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "SPIRVConvergenceRegionAnalysis.h"
|
|
#include "SPIRV.h"
|
|
#include "llvm/Analysis/LoopInfo.h"
|
|
#include "llvm/IR/Dominators.h"
|
|
#include "llvm/IR/IntrinsicInst.h"
|
|
#include "llvm/InitializePasses.h"
|
|
#include "llvm/Transforms/Utils/LoopSimplify.h"
|
|
#include <optional>
|
|
#include <queue>
|
|
|
|
#define DEBUG_TYPE "spirv-convergence-region-analysis"
|
|
|
|
using namespace llvm;
|
|
using namespace SPIRV;
|
|
|
|
// Registers SPIRVConvergenceRegionAnalysisWrapperPass with the legacy pass
// registry under the argument "convergence-region", and declares the passes
// it depends on: LoopSimplify, the dominator tree and loop info.
// NOTE(review): per the INITIALIZE_PASS_BEGIN/END macro signature, the two
// trailing `true` arguments are the "CFG only" and "is analysis" flags.
INITIALIZE_PASS_BEGIN(SPIRVConvergenceRegionAnalysisWrapperPass,
                      "convergence-region",
                      "SPIRV convergence regions analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(SPIRVConvergenceRegionAnalysisWrapperPass,
                    "convergence-region", "SPIRV convergence regions analysis",
                    true, true)
|
|
|
|
namespace {
|
|
|
|
template <typename BasicBlockType, typename IntrinsicInstType>
|
|
std::optional<IntrinsicInstType *>
|
|
getConvergenceTokenInternal(BasicBlockType *BB) {
|
|
static_assert(std::is_const_v<IntrinsicInstType> ==
|
|
std::is_const_v<BasicBlockType>,
|
|
"Constness must match between input and output.");
|
|
static_assert(std::is_same_v<BasicBlock, std::remove_const_t<BasicBlockType>>,
|
|
"Input must be a basic block.");
|
|
static_assert(
|
|
std::is_same_v<IntrinsicInst, std::remove_const_t<IntrinsicInstType>>,
|
|
"Output type must be an intrinsic instruction.");
|
|
|
|
for (auto &I : *BB) {
|
|
if (auto *CI = dyn_cast<ConvergenceControlInst>(&I)) {
|
|
// Make sure that the anchor or entry intrinsics did not reach here with a
|
|
// parent token. This should have failed the verifier.
|
|
assert(CI->isLoop() ||
|
|
!CI->getOperandBundle(LLVMContext::OB_convergencectrl));
|
|
return CI;
|
|
}
|
|
|
|
if (auto *CI = dyn_cast<CallInst>(&I)) {
|
|
auto OB = CI->getOperandBundle(LLVMContext::OB_convergencectrl);
|
|
if (!OB.has_value())
|
|
continue;
|
|
return dyn_cast<IntrinsicInst>(OB.value().Inputs[0]);
|
|
}
|
|
}
|
|
|
|
return std::nullopt;
|
|
}
|
|
} // anonymous namespace
|
|
|
|
// Given a ConvergenceRegion tree with |Start| as its root, finds the smallest
|
|
// region |Entry| belongs to. If |Entry| does not belong to the region defined
|
|
// by |Start|, this function returns |nullptr|.
|
|
static ConvergenceRegion *findParentRegion(ConvergenceRegion *Start,
|
|
BasicBlock *Entry) {
|
|
ConvergenceRegion *Candidate = nullptr;
|
|
ConvergenceRegion *NextCandidate = Start;
|
|
|
|
while (Candidate != NextCandidate && NextCandidate != nullptr) {
|
|
Candidate = NextCandidate;
|
|
NextCandidate = nullptr;
|
|
|
|
// End of the search, we can return.
|
|
if (Candidate->Children.size() == 0)
|
|
return Candidate;
|
|
|
|
for (auto *Child : Candidate->Children) {
|
|
if (Child->Blocks.count(Entry) != 0) {
|
|
NextCandidate = Child;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
return Candidate;
|
|
}
|
|
|
|
std::optional<IntrinsicInst *>
|
|
llvm::SPIRV::getConvergenceToken(BasicBlock *BB) {
|
|
return getConvergenceTokenInternal<BasicBlock, IntrinsicInst>(BB);
|
|
}
|
|
|
|
std::optional<const IntrinsicInst *>
|
|
llvm::SPIRV::getConvergenceToken(const BasicBlock *BB) {
|
|
return getConvergenceTokenInternal<const BasicBlock, const IntrinsicInst>(BB);
|
|
}
|
|
|
|
// Builds the top-level region of |F|. It is entered through the function's
// entry block, contains every block of |F|, and its exits are the blocks
// terminated by a return instruction. It has no parent.
ConvergenceRegion::ConvergenceRegion(DominatorTree &DT, LoopInfo &LI,
                                     Function &F)
    : DT(DT), LI(LI), Parent(nullptr) {
  BasicBlock *EntryBlock = &F.getEntryBlock();
  Entry = EntryBlock;
  // The region's token is whatever getConvergenceToken finds in the entry
  // block (std::nullopt if none).
  ConvergenceToken = getConvergenceToken(EntryBlock);

  for (BasicBlock &Block : F) {
    Blocks.insert(&Block);
    const Instruction *Term = Block.getTerminator();
    if (isa<ReturnInst>(Term))
      Exits.insert(&Block);
  }
}
|
|
|
|
// Builds a region from precomputed parts, taking ownership of the |Blocks| and
// |Exits| sets. In debug builds, checks that every exit and the entry block
// are members of the region. The parameters shadow the members of the same
// names, hence the explicit |this->| in the checks.
ConvergenceRegion::ConvergenceRegion(
    DominatorTree &DT, LoopInfo &LI,
    std::optional<IntrinsicInst *> ConvergenceToken, BasicBlock *Entry,
    SmallPtrSet<BasicBlock *, 8> &&Blocks, SmallPtrSet<BasicBlock *, 2> &&Exits)
    : DT(DT), LI(LI), ConvergenceToken(ConvergenceToken), Entry(Entry),
      Exits(std::move(Exits)), Blocks(std::move(Blocks)) {
  for ([[maybe_unused]] BasicBlock *ExitBlock : this->Exits)
    assert(this->Blocks.contains(ExitBlock));
  assert(this->Blocks.contains(this->Entry));
}
|
|
|
|
void ConvergenceRegion::releaseMemory() {
|
|
// Parent memory is owned by the parent.
|
|
Parent = nullptr;
|
|
for (auto *Child : Children) {
|
|
Child->releaseMemory();
|
|
delete Child;
|
|
}
|
|
Children.resize(0);
|
|
}
|
|
|
|
void ConvergenceRegion::dump(const unsigned IndentSize) const {
|
|
const std::string Indent(IndentSize, '\t');
|
|
dbgs() << Indent << this << ": {\n";
|
|
dbgs() << Indent << " Parent: " << Parent << "\n";
|
|
|
|
if (ConvergenceToken.value_or(nullptr)) {
|
|
dbgs() << Indent
|
|
<< " ConvergenceToken: " << ConvergenceToken.value()->getName()
|
|
<< "\n";
|
|
}
|
|
|
|
if (Entry->getName() != "")
|
|
dbgs() << Indent << " Entry: " << Entry->getName() << "\n";
|
|
else
|
|
dbgs() << Indent << " Entry: " << Entry << "\n";
|
|
|
|
dbgs() << Indent << " Exits: { ";
|
|
for (const auto &Exit : Exits) {
|
|
if (Exit->getName() != "")
|
|
dbgs() << Exit->getName() << ", ";
|
|
else
|
|
dbgs() << Exit << ", ";
|
|
}
|
|
dbgs() << " }\n";
|
|
|
|
dbgs() << Indent << " Blocks: { ";
|
|
for (const auto &Block : Blocks) {
|
|
if (Block->getName() != "")
|
|
dbgs() << Block->getName() << ", ";
|
|
else
|
|
dbgs() << Block << ", ";
|
|
}
|
|
dbgs() << " }\n";
|
|
|
|
dbgs() << Indent << " Children: {\n";
|
|
for (const auto Child : Children)
|
|
Child->dump(IndentSize + 2);
|
|
dbgs() << Indent << " }\n";
|
|
|
|
dbgs() << Indent << "}\n";
|
|
}
|
|
|
|
namespace {
|
|
class ConvergenceRegionAnalyzer {
|
|
public:
|
|
ConvergenceRegionAnalyzer(Function &F, DominatorTree &DT, LoopInfo &LI)
|
|
: DT(DT), LI(LI), F(F) {}
|
|
|
|
private:
|
|
bool isBackEdge(const BasicBlock *From, const BasicBlock *To) const {
|
|
if (From == To)
|
|
return true;
|
|
|
|
// We only handle loop in the simplified form. This means:
|
|
// - a single back-edge, a single latch.
|
|
// - meaning the back-edge target can only be the loop header.
|
|
// - meaning the From can only be the loop latch.
|
|
if (!LI.isLoopHeader(To))
|
|
return false;
|
|
|
|
auto *L = LI.getLoopFor(To);
|
|
if (L->contains(From) && L->isLoopLatch(From))
|
|
return true;
|
|
|
|
return false;
|
|
}
|
|
|
|
std::unordered_set<BasicBlock *>
|
|
findPathsToMatch(LoopInfo &LI, BasicBlock *From,
|
|
std::function<bool(const BasicBlock *)> isMatch) const {
|
|
std::unordered_set<BasicBlock *> Output;
|
|
|
|
if (isMatch(From))
|
|
Output.insert(From);
|
|
|
|
auto *Terminator = From->getTerminator();
|
|
for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
|
|
auto *To = Terminator->getSuccessor(i);
|
|
// Ignore back edges.
|
|
if (isBackEdge(From, To))
|
|
continue;
|
|
|
|
auto ChildSet = findPathsToMatch(LI, To, isMatch);
|
|
if (ChildSet.size() == 0)
|
|
continue;
|
|
|
|
Output.insert(ChildSet.begin(), ChildSet.end());
|
|
Output.insert(From);
|
|
if (LI.isLoopHeader(From)) {
|
|
auto *L = LI.getLoopFor(From);
|
|
for (auto *BB : L->getBlocks()) {
|
|
Output.insert(BB);
|
|
}
|
|
}
|
|
}
|
|
|
|
return Output;
|
|
}
|
|
|
|
SmallPtrSet<BasicBlock *, 2>
|
|
findExitNodes(const SmallPtrSetImpl<BasicBlock *> &RegionBlocks) {
|
|
SmallPtrSet<BasicBlock *, 2> Exits;
|
|
|
|
for (auto *B : RegionBlocks) {
|
|
auto *Terminator = B->getTerminator();
|
|
for (unsigned i = 0; i < Terminator->getNumSuccessors(); ++i) {
|
|
auto *Child = Terminator->getSuccessor(i);
|
|
if (RegionBlocks.count(Child) == 0)
|
|
Exits.insert(B);
|
|
}
|
|
}
|
|
|
|
return Exits;
|
|
}
|
|
|
|
public:
|
|
ConvergenceRegionInfo analyze() {
|
|
ConvergenceRegion *TopLevelRegion = new ConvergenceRegion(DT, LI, F);
|
|
std::queue<Loop *> ToProcess;
|
|
for (auto *L : LI.getLoopsInPreorder())
|
|
ToProcess.push(L);
|
|
|
|
while (ToProcess.size() != 0) {
|
|
auto *L = ToProcess.front();
|
|
ToProcess.pop();
|
|
|
|
auto CT = getConvergenceToken(L->getHeader());
|
|
SmallPtrSet<BasicBlock *, 8> RegionBlocks(llvm::from_range, L->blocks());
|
|
SmallVector<BasicBlock *> LoopExits;
|
|
L->getExitingBlocks(LoopExits);
|
|
if (CT.has_value()) {
|
|
for (auto *Exit : LoopExits) {
|
|
auto N = findPathsToMatch(LI, Exit, [&CT](const BasicBlock *block) {
|
|
auto Token = getConvergenceToken(block);
|
|
if (Token == std::nullopt)
|
|
return false;
|
|
return Token.value() == CT.value();
|
|
});
|
|
RegionBlocks.insert_range(N);
|
|
}
|
|
}
|
|
|
|
auto RegionExits = findExitNodes(RegionBlocks);
|
|
ConvergenceRegion *Region = new ConvergenceRegion(
|
|
DT, LI, CT, L->getHeader(), std::move(RegionBlocks),
|
|
std::move(RegionExits));
|
|
Region->Parent = findParentRegion(TopLevelRegion, Region->Entry);
|
|
assert(Region->Parent != nullptr && "This is impossible.");
|
|
Region->Parent->Children.push_back(Region);
|
|
}
|
|
|
|
return ConvergenceRegionInfo(TopLevelRegion);
|
|
}
|
|
|
|
private:
|
|
DominatorTree &DT;
|
|
LoopInfo &LI;
|
|
Function &F;
|
|
};
|
|
} // anonymous namespace
|
|
|
|
// Computes the convergence-region tree of |F| from its dominator tree and
// loop info, using a ConvergenceRegionAnalyzer.
ConvergenceRegionInfo llvm::SPIRV::getConvergenceRegions(Function &F,
                                                         DominatorTree &DT,
                                                         LoopInfo &LI) {
  return ConvergenceRegionAnalyzer(F, DT, LI).analyze();
}
|
|
|
|
// Identifier of the legacy wrapper pass. It is handed to the FunctionPass
// base constructor below.
char SPIRVConvergenceRegionAnalysisWrapperPass::ID = 0;

// Default constructor: initializes the FunctionPass base with |ID|.
SPIRVConvergenceRegionAnalysisWrapperPass::
    SPIRVConvergenceRegionAnalysisWrapperPass()
    : FunctionPass(ID) {}
|
|
|
|
// Legacy pass manager entry point. It fetches the dominator tree and loop
// info of |F|, then stores the computed region tree in |CRI|.
bool SPIRVConvergenceRegionAnalysisWrapperPass::runOnFunction(Function &F) {
  auto &DomTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto &Loops = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  CRI = SPIRV::getConvergenceRegions(F, DomTree, Loops);

  // The IR is only inspected, never changed.
  return false;
}
|
|
|
|
SPIRVConvergenceRegionAnalysis::Result
|
|
SPIRVConvergenceRegionAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
|
|
Result CRI;
|
|
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
|
|
auto &LI = AM.getResult<LoopAnalysis>(F);
|
|
CRI = SPIRV::getConvergenceRegions(F, DT, LI);
|
|
return CRI;
|
|
}
|
|
|
|
AnalysisKey SPIRVConvergenceRegionAnalysis::Key;
|