See: https://github.com/llvm/llvm-project/issues/131779 Extends uniformity analysis to support instructions whose uniformity depends on which specific operands are uniform. Introduces `InstructionUniformity::Custom` and a target hook `TTI::isUniform(I, UniformArgs)` that allows targets to define custom uniformity rules. During propagation, custom candidates are checked via the target hook. If we can prove they are uniform, we skip marking them divergent and let iterative propagation re-evaluate as operands change. Implements AMDGPU's `llvm.amdgcn.wave.shuffle` rules (uniform when either operand is uniform, divergent only when both are divergent) as the motivating example. This inverted-logic approach is critical for correctness: proving uniformity early during propagation would be unsafe, as operands can transition from uniform to divergent during divergence propagation. --------- Co-authored-by: Matt Arsenault <arsenm2@gmail.com>
208 lines
6.8 KiB
C++
208 lines
6.8 KiB
C++
//===- UniformityAnalysis.cpp ---------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Analysis/UniformityAnalysis.h"
|
|
#include "llvm/ADT/GenericUniformityImpl.h"
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
#include "llvm/Analysis/CycleAnalysis.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/IR/Dominators.h"
|
|
#include "llvm/IR/InstIterator.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/InitializePasses.h"
|
|
|
|
using namespace llvm;
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
|
|
const Instruction &I) const {
|
|
return isDivergent((const Value *)&I);
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
|
|
const Instruction &Instr) {
|
|
return markDivergent(cast<Value>(&Instr));
|
|
}
|
|
|
|
/// Seeds the analysis from the target's per-value classification.
///
/// Every instruction in the function is classified via
/// TTI->getInstructionUniformity and recorded as a uniform override, as
/// divergent, or as a custom-uniformity candidate. Function arguments that the
/// target reports as never uniform are marked divergent as well.
template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
  for (auto &Inst : instructions(F)) {
    switch (TTI->getInstructionUniformity(&Inst)) {
    case InstructionUniformity::AlwaysUniform:
      addUniformOverride(Inst);
      break;
    case InstructionUniformity::NeverUniform:
      markDivergent(Inst);
      break;
    case InstructionUniformity::Custom:
      // Uniformity depends on the operands; decided later via isCustomUniform.
      addCustomUniformityCandidate(&Inst);
      break;
    case InstructionUniformity::Default:
      // Nothing to record up front.
      break;
    }
  }

  for (auto &FormalArg : F.args()) {
    const InstructionUniformity ArgKind =
        TTI->getInstructionUniformity(&FormalArg);
    if (ArgKind != InstructionUniformity::NeverUniform)
      continue;
    markDivergent(&FormalArg);
  }
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
|
|
const Value *V) {
|
|
for (const auto *User : V->users()) {
|
|
if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
|
|
markDivergent(*UserInstr);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
|
|
const Instruction &Instr) {
|
|
assert(!isAlwaysUniform(Instr));
|
|
if (Instr.isTerminator())
|
|
return;
|
|
pushUsers(cast<Value>(&Instr));
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
|
|
const Instruction &I, const Cycle &DefCycle) const {
|
|
assert(!isAlwaysUniform(I));
|
|
for (const Use &U : I.operands()) {
|
|
if (auto *I = dyn_cast<Instruction>(&U)) {
|
|
if (DefCycle.contains(I->getParent()))
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
void llvm::GenericUniformityAnalysisImpl<
|
|
SSAContext>::propagateTemporalDivergence(const Instruction &I,
|
|
const Cycle &DefCycle) {
|
|
for (auto *User : I.users()) {
|
|
auto *UserInstr = cast<Instruction>(User);
|
|
if (DefCycle.contains(UserInstr->getParent()))
|
|
continue;
|
|
markDivergent(*UserInstr);
|
|
recordTemporalDivergence(&I, UserInstr, &DefCycle);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
|
|
const Use &U) const {
|
|
const auto *V = U.get();
|
|
if (isDivergent(V))
|
|
return true;
|
|
if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
|
|
const auto *UseInstr = cast<Instruction>(U.getUser());
|
|
return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <>
|
|
bool GenericUniformityAnalysisImpl<SSAContext>::isCustomUniform(
|
|
const Instruction &I) const {
|
|
SmallBitVector UniformArgs(I.getNumOperands());
|
|
for (auto [Idx, Use] : enumerate(I.operands()))
|
|
UniformArgs[Idx] = !isDivergentUse(Use);
|
|
return TTI->isUniform(&I, UniformArgs);
|
|
}
|
|
|
|
// This ensures explicit instantiation of
|
|
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
|
|
template class llvm::GenericUniformityInfo<SSAContext>;
|
|
template struct llvm::GenericUniformityAnalysisImplDeleter<
|
|
llvm::GenericUniformityAnalysisImpl<SSAContext>>;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UniformityInfoAnalysis and related pass implementations
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds the UniformityInfo for \p F from the dominator tree, target
/// transform info and cycle info obtained from \p FAM. The uniformity
/// computation itself only runs when the target reports branch divergence for
/// \p F; otherwise the freshly constructed result is returned as is.
llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
                                                 FunctionAnalysisManager &FAM) {
  auto &DomTree = FAM.getResult<DominatorTreeAnalysis>(F);
  auto &TargetInfo = FAM.getResult<TargetIRAnalysis>(F);
  auto &Cycles = FAM.getResult<CycleAnalysis>(F);

  UniformityInfo Result{DomTree, Cycles, &TargetInfo};

  // Skip computation if we can assume everything is uniform.
  const bool MayDiverge = TargetInfo.hasBranchDivergence(&F);
  if (MayDiverge)
    Result.compute();

  return Result;
}
|
|
|
|
// Out-of-line definition of the static AnalysisKey member of
// UniformityInfoAnalysis.
AnalysisKey UniformityInfoAnalysis::Key;
|
|
|
|
/// Constructs a printer pass that writes its output to \p OS.
UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
    : OS(OS) {
}
|
|
|
|
/// Prints a header naming \p F followed by the UniformityInfoAnalysis result
/// for \p F to the stream supplied at construction. Reports all analyses as
/// preserved.
PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  OS << "UniformityInfo for function '" << F.getName() << "':\n";

  const auto &Info = AM.getResult<UniformityInfoAnalysis>(F);
  Info.print(OS);

  return PreservedAnalyses::all();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// UniformityInfoWrapperPass Implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Out-of-line definition of the static ID member; it is handed to the
// FunctionPass base class by the UniformityInfoWrapperPass constructor.
char UniformityInfoWrapperPass::ID = 0;
|
|
|
|
/// Default constructor: initializes the FunctionPass base with this pass's ID.
UniformityInfoWrapperPass::UniformityInfoWrapperPass()
    : FunctionPass(ID) {
}
|
|
|
|
// Pass registration for UniformityInfoWrapperPass under the argument name
// "uniformity", listing the dominator tree, cycle info and target transform
// info wrapper passes as its dependencies.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", false, true)
|
|
|
|
/// Declares this pass's analysis usage on \p AU: it preserves everything,
/// requires the dominator tree and target transform info wrapper passes, and
/// requires the cycle info wrapper pass transitively.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();

  // Inputs consumed by runOnFunction.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
|
|
|
|
/// Rebuilds the stored UniformityInfo for \p F from the cycle info, dominator
/// tree and target transform info wrapper passes, and remembers \p F for
/// later printing. The uniformity computation only runs when the target
/// reports branch divergence for \p F. Always returns false.
bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
  m_function = &F;

  auto &CI = getAnalysis<CycleInfoWrapperPass>().getResult();
  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  m_uniformityInfo = UniformityInfo{DT, CI, &TTI};

  // Skip computation if we can assume everything is uniform.
  const bool MayDiverge = TTI.hasBranchDivergence(m_function);
  if (MayDiverge)
    m_uniformityInfo.compute();

  return false;
}
|
|
|
|
/// Writes a header naming the stored function, followed by the stored
/// uniformity info, to \p OS. The Module argument is unused.
void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
  const auto FnName = m_function->getName();
  OS << "UniformityInfo for function '" << FnName << "':\n";
  m_uniformityInfo.print(OS);
}
|
|
|
|
/// Drops the cached state: clears the stored function pointer and replaces
/// the stored uniformity info with a default-constructed one.
void UniformityInfoWrapperPass::releaseMemory() {
  m_function = nullptr;
  m_uniformityInfo = UniformityInfo{};
}
|