llvm-project/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
Pankaj Dwivedi 6e2a720dd1
[AMDGPU][Uniformity][TTI] Make Uniformity Analysis Operand-Aware via Custom Uniformity Checks (#137639)
See: https://github.com/llvm/llvm-project/issues/131779

Extends uniformity analysis to support instructions whose uniformity
depends on which specific operands are uniform. Introduces
`InstructionUniformity::Custom` and a target hook `TTI::isUniform(I,
UniformArgs)` that allows targets to define custom uniformity rules.
During propagation, custom candidates are checked via the target hook.
If we can prove they are uniform, we skip marking them divergent and let
iterative propagation re-evaluate as operands change.

Implements AMDGPU's `llvm.amdgcn.wave.shuffle` rules (uniform when
either operand is uniform, divergent only when both are divergent) as
the motivating example.

This inverted-logic approach is critical for correctness: proving
uniformity early during propagation would be unsafe, as operands can
transition from uniform to divergent during divergence propagation.

---------

Co-authored-by: Matt Arsenault <arsenm2@gmail.com>
2026-03-24 22:24:57 +05:30

280 lines
8.9 KiB
C++

//===- MachineUniformityAnalysis.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
#include "llvm/ADT/GenericUniformityImpl.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/MachineCycleAnalysis.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineSSAContext.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/InitializePasses.h"
using namespace llvm;
/// Returns true if at least one register defined by \p I is currently
/// reported as divergent by isDivergent().
template <>
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
    const MachineInstr &I) const {
  return llvm::any_of(I.all_defs(), [this](const MachineOperand &DefOp) {
    return isDivergent(DefOp.getReg());
  });
}
/// Marks the virtual registers defined by \p Instr as divergent, skipping any
/// register that the target register info reports as uniform.
///
/// \returns true if markDivergent() returned true for at least one def.
template <>
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
    const MachineInstr &Instr) {
  const auto &RegInfo = F.getRegInfo();
  const auto &RegBankInfo = *F.getSubtarget().getRegBankInfo();
  const auto &TargetRegInfo = *RegInfo.getTargetRegisterInfo();

  bool Changed = false;
  for (const MachineOperand &DefOp : Instr.all_defs()) {
    const Register DefReg = DefOp.getReg();
    // Only virtual registers are tracked here.
    if (!DefReg.isVirtual())
      continue;
    assert(!DefOp.getSubReg());
    // The target may declare a register uniform regardless of its definition.
    if (TargetRegInfo.isUniformReg(RegInfo, RegBankInfo, DefReg))
      continue;
    if (markDivergent(DefReg))
      Changed = true;
  }
  return Changed;
}
/// Seeds the analysis by asking the target's instruction info for the
/// uniformity kind of every instruction in the function.
template <>
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
  const auto &TII = *F.getSubtarget().getInstrInfo();

  for (const MachineBasicBlock &MBB : F) {
    for (const MachineInstr &MI : MBB) {
      switch (TII.getInstructionUniformity(MI)) {
      case InstructionUniformity::AlwaysUniform:
        addUniformOverride(MI);
        break;
      case InstructionUniformity::NeverUniform:
        markDivergent(MI);
        break;
      // Neither kind seeds anything at initialization time.
      case InstructionUniformity::Custom:
      case InstructionUniformity::Default:
        break;
      }
    }
  }
}
/// Marks every instruction that uses the divergent register \p Reg as
/// divergent.
template <>
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
    Register Reg) {
  assert(isDivergent(Reg));
  for (MachineInstr &User : F.getRegInfo().use_instructions(Reg))
    markDivergent(User);
}
/// Forwards divergence from \p Instr to the users of each of its divergent
/// defs. Terminators are skipped entirely.
template <>
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
    const MachineInstr &Instr) {
  assert(!isAlwaysUniform(Instr));
  if (!Instr.isTerminator()) {
    for (const MachineOperand &DefOp : Instr.all_defs()) {
      const Register DefReg = DefOp.getReg();
      if (!isDivergent(DefReg))
        continue;
      pushUsers(DefReg);
    }
  }
}
/// Returns true if \p I reads a register whose defining instruction lives in a
/// block contained in \p DefCycle.
///
/// The answer is conservative: any read of a physical register, and any read
/// of a virtual register for which no defining instruction can be found,
/// yields true.
template <>
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
    const MachineInstr &I, const MachineCycle &DefCycle) const {
  assert(!isAlwaysUniform(I));
  const auto &RegInfo = F.getRegInfo();
  for (const MachineOperand &Op : I.operands()) {
    if (!Op.isReg() || !Op.readsReg())
      continue;
    const Register Reg = Op.getReg();

    // FIXME: Physical registers need to be properly checked instead of always
    // returning true
    if (Reg.isPhysical())
      return true;

    // A register that is neither physical nor virtual (e.g. the null register)
    // has no defining instruction to look up, so it cannot come from the cycle.
    if (!Reg.isVirtual())
      continue;

    // getVRegDef() returns null when no def is found. Stay conservative in
    // that case, matching how isDivergentUse() treats a missing def.
    const MachineInstr *Def = RegInfo.getVRegDef(Reg);
    if (!Def)
      return true;
    if (DefCycle.contains(Def->getParent()))
      return true;
  }
  return false;
}
/// For each virtual register defined by \p I, marks every user located outside
/// \p DefCycle as divergent and records the temporal divergence for that
/// (register, user, cycle) triple.
template <>
void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
    propagateTemporalDivergence(const MachineInstr &I,
                                const MachineCycle &DefCycle) {
  const auto &MRI = F.getRegInfo();
  for (const MachineOperand &DefOp : I.all_defs()) {
    const Register DefReg = DefOp.getReg();
    if (!DefReg.isVirtual())
      continue;
    for (MachineInstr &User : MRI.use_instructions(DefReg)) {
      // Users inside the defining cycle are left untouched.
      if (!DefCycle.contains(User.getParent())) {
        markDivergent(User);
        recordTemporalDivergence(DefReg, &User, &DefCycle);
      }
    }
  }
}
/// Returns true if the use \p U must be treated as divergent.
///
/// Non-register operands are never divergent. A register use is divergent if
/// the register itself is divergent, if getOneDef() yields no defining operand
/// for it, or if the use is temporally divergent with respect to its def.
template <>
bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
    const MachineOperand &U) const {
  if (!U.isReg())
    return false;

  const Register UsedReg = U.getReg();
  if (isDivergent(UsedReg))
    return true;

  // Without a single known def, conservatively report divergence.
  const MachineOperand *DefOp = F.getRegInfo().getOneDef(UsedReg);
  if (!DefOp)
    return true;

  const MachineInstr &DefMI = *DefOp->getParent();
  const MachineBasicBlock &UseBlock = *U.getParent()->getParent();
  return isTemporalDivergent(UseBlock, DefMI);
}
template <>
bool GenericUniformityAnalysisImpl<MachineSSAContext>::isCustomUniform(
const MachineInstr &MI) const {
llvm_unreachable("no MIR instructions use Custom uniformity yet");
}
// Explicitly instantiate the uniformity info class and the impl deleter for
// MachineSSAContext in this translation unit.
// This ensures explicit instantiation of
// GenericUniformityAnalysisImpl::ImplDeleter::operator()
template class llvm::GenericUniformityInfo<MachineSSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
/// Builds the uniformity info for \p F from \p cycleInfo and \p domTree.
///
/// \p F must be in SSA form. The analysis itself is only run when
/// \p HasBranchDivergence is set; otherwise the freshly constructed info is
/// returned without calling compute().
MachineUniformityInfo llvm::computeMachineUniformityInfo(
    MachineFunction &F, const MachineCycleInfo &cycleInfo,
    const MachineDominatorTree &domTree, bool HasBranchDivergence) {
  assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");

  MachineUniformityInfo Info(domTree, cycleInfo);
  if (!HasBranchDivergence)
    return Info;

  Info.compute();
  return Info;
}
namespace {
// MachineFunctionPass that prints the result of MachineUniformityAnalysisPass
// for each machine function it runs on. The members are defined further down
// in this file.
class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
public:
// Pass identification.
static char ID;
MachineUniformityInfoPrinterPass();
// Prints the analysis result; always returns false.
bool runOnMachineFunction(MachineFunction &F) override;
// Declares the dependency on MachineUniformityAnalysisPass.
void getAnalysisUsage(AnalysisUsage &AU) const override;
};
} // namespace
// Definition of the static key object of MachineUniformityAnalysis.
// NOTE(review): presumably the key's address is what identifies this analysis
// to the analysis manager -- confirm against llvm/IR/PassManager.h.
AnalysisKey MachineUniformityAnalysis::Key;
/// Computes the uniformity info for \p MF from the dominator tree and cycle
/// analysis results held by \p MFAM. Whether the analysis actually runs is
/// decided by TTI::hasBranchDivergence() for the underlying IR function.
MachineUniformityAnalysis::Result
MachineUniformityAnalysis::run(MachineFunction &MF,
                               MachineFunctionAnalysisManager &MFAM) {
  auto &MDT = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
  auto &Cycles = MFAM.getResult<MachineCycleAnalysis>(MF);

  // TTI is an IR-level analysis, so it is fetched through the function
  // analysis manager proxy.
  auto &IRAnalyses =
      MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
          .getManager();
  auto &IRFn = MF.getFunction();
  auto &TargetInfo = IRAnalyses.getResult<TargetIRAnalysis>(IRFn);
  const bool HasBranchDivergence = TargetInfo.hasBranchDivergence(&IRFn);

  return computeMachineUniformityInfo(MF, Cycles, MDT, HasBranchDivergence);
}
/// Prints a header naming the function followed by the uniformity info for
/// \p MF to the pass's output stream. Reports all analyses as preserved.
PreservedAnalyses
MachineUniformityPrinterPass::run(MachineFunction &MF,
                                  MachineFunctionAnalysisManager &MFAM) {
  auto &Info = MFAM.getResult<MachineUniformityAnalysis>(MF);

  // Header line: "MachineUniformityInfo for function: <name>".
  OS << "MachineUniformityInfo for function: ";
  const auto &IRFn = MF.getFunction();
  IRFn.printAsOperand(OS, /*PrintType=*/false);
  OS << '\n';

  Info.print(OS);
  return PreservedAnalyses::all();
}
// Pass identification for MachineUniformityAnalysisPass.
char MachineUniformityAnalysisPass::ID = 0;
// The constructor only forwards the pass ID to the MachineFunctionPass base.
MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
: MachineFunctionPass(ID) {}
// Registers MachineUniformityAnalysisPass under the name "machine-uniformity"
// and lists the cycle info and dominator tree wrapper passes as dependencies.
// NOTE(review): the trailing `false, true` are presumably the macro's
// cfg-only and is-analysis flags -- confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
"Machine Uniformity Info Analysis", false, true)
// Declares that this pass preserves all analyses and requires the machine
// cycle info (transitively) and the machine dominator tree.
void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
// NOTE(review): runOnMachineFunction() hands both the cycle info and the
// dominator tree to computeMachineUniformityInfo(), yet only the cycle info is
// required transitively -- confirm the dominator tree does not need to stay
// alive as long as the computed result.
AU.addRequired<MachineDominatorTreeWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}
/// Computes the uniformity info for \p MF and stores it in UI. The machine
/// function is not modified, so this always returns false.
bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
  auto &MDT = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
  auto &Cycles = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();

  // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
  // default NoTTI
  UI = computeMachineUniformityInfo(MF, Cycles, MDT,
                                    /*HasBranchDivergence=*/true);
  return false;
}
/// Prints a header naming the analyzed function followed by the stored
/// uniformity info to \p OS. The module argument is unused.
void MachineUniformityAnalysisPass::print(raw_ostream &OS,
                                          const Module *) const {
  // UI refers to a machine function; the header names its IR function.
  const auto &IRFn = UI.getFunction().getFunction();

  OS << "MachineUniformityInfo for function: ";
  IRFn.printAsOperand(OS, /*PrintType=*/false);
  OS << '\n';
  UI.print(OS);
}
// Pass identification for MachineUniformityInfoPrinterPass.
char MachineUniformityInfoPrinterPass::ID = 0;
// The constructor only forwards the pass ID to the MachineFunctionPass base.
MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
: MachineFunctionPass(ID) {}
// Registers the printer pass under the name "print-machine-uniformity" with
// MachineUniformityAnalysisPass as its only dependency.
// NOTE(review): the trailing `true, true` are presumably the macro's cfg-only
// and is-analysis flags -- confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
"print-machine-uniformity",
"Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
"print-machine-uniformity",
"Print Machine Uniformity Info Analysis", true, true)
// Declares that the printer preserves all analyses and requires
// MachineUniformityAnalysisPass, whose result it prints.
void MachineUniformityInfoPrinterPass::getAnalysisUsage(
AnalysisUsage &AU) const {
AU.setPreservesAll();
AU.addRequired<MachineUniformityAnalysisPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}
/// Prints the result of MachineUniformityAnalysisPass to errs(). Nothing is
/// modified, so this always returns false.
bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
    MachineFunction &F) {
  getAnalysis<MachineUniformityAnalysisPass>().print(errs());
  return false;
}