llvm-project/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
Kazu Hirata 82d5dd28b4
[RISCV] Remove unused includes (NFC) (#115814)
Identified with misc-include-cleaner.
2024-11-11 22:54:54 -08:00

885 lines
29 KiB
C++

//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass reduces the VL where possible at the MI level, before VSETVLI
// instructions are inserted.
//
// The purpose of this optimization is to make the VL argument, for instructions
// that have a VL argument, as small as possible. This is implemented by
// visiting each instruction in reverse order and, if it has a VL argument,
// checking whether that VL can be reduced.
//
//===---------------------------------------------------------------------===//
#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/InitializePasses.h"
using namespace llvm;
#define DEBUG_TYPE "riscv-vl-optimizer"
#define PASS_NAME "RISC-V VL Optimizer"
namespace {
/// MachineFunction pass that tries to shrink the VL operand of supported RVV
/// pseudo instructions to the VL used by the instructions consuming their
/// result.
class RISCVVLOptimizer : public MachineFunctionPass {
  // Per-function state, assigned at the top of runOnMachineFunction(). The
  // pointers start out null rather than indeterminate.
  const MachineRegisterInfo *MRI = nullptr;
  const MachineDominatorTree *MDT = nullptr;

public:
  static char ID;

  RISCVVLOptimizer() : MachineFunctionPass(ID) {}

  /// Pass entry point; returns the accumulated result of tryReduceVL().
  bool runOnMachineFunction(MachineFunction &MF) override;

  /// Declares that the CFG is preserved and that the machine dominator tree
  /// analysis is required.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<MachineDominatorTreeWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return PASS_NAME; }

private:
  /// Examine the users of MI's result. Returns false if any user prevents a
  /// VL reduction; otherwise CommonVL refers to the VL operand shared by the
  /// users (or stays unchanged if no user contributed one).
  bool checkUsers(const MachineOperand *&CommonVL, MachineInstr &MI);

  /// Try to reduce the VL of MI and, transitively, of the instructions that
  /// define its vector inputs. Returns true if something was rewritten.
  bool tryReduceVL(MachineInstr &MI);

  /// Return true if MI is an instruction this pass is willing to optimize.
  bool isCandidate(const MachineInstr &MI) const;
};
} // end anonymous namespace
// Definition of the static pass identifier declared in RISCVVLOptimizer.
char RISCVVLOptimizer::ID = 0;
// Register the pass under DEBUG_TYPE / PASS_NAME and record its dependency on
// the machine dominator tree wrapper pass, matching getAnalysisUsage().
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
/// Factory function: returns a newly allocated RISCVVLOptimizer pass.
FunctionPass *llvm::createRISCVVLOptimizerPass() {
  FunctionPass *const Pass = new RISCVVLOptimizer();
  return Pass;
}
/// Return true if R is a vector register, false otherwise.
///
/// A physical register is tested for membership in RISCV::VRRegClass; a
/// virtual register is classified through the TSFlags of the register class
/// that MRI records for it.
static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
  if (!R.isPhysical()) {
    // Virtual register: look at the flags of its assigned register class.
    return RISCVRI::isVRegClass(MRI->getRegClass(R)->TSFlags);
  }
  return RISCV::VRRegClass.contains(R);
}
/// Represents the EMUL and EEW of a MachineOperand.
///
/// An object is either Known (EMUL and Log2EEW are meaningful) or Unknown
/// (default constructed; EMUL is empty and Log2EEW is zero).
struct OperandInfo {
  enum class State {
    Unknown,
    Known,
  } S;

  // Represent as 1,2,4,8, ... and fractional indicator. This is because
  // EMUL can take on values that don't map to RISCVII::VLMUL values exactly.
  // For example, a mask operand can have an EMUL less than MF8.
  std::optional<std::pair<unsigned, bool>> EMUL;

  // Base-2 logarithm of the EEW. Initialized here so that an Unknown object
  // never carries an indeterminate value.
  unsigned Log2EEW = 0;

  /// Known info built from a RISCVII::VLMUL encoding, which is decoded into
  /// the (value, is-fractional) representation.
  OperandInfo(RISCVII::VLMUL EMUL, unsigned Log2EEW)
      : S(State::Known), EMUL(RISCVVType::decodeVLMUL(EMUL)),
        Log2EEW(Log2EEW) {}

  /// Known info built from an already decoded (value, is-fractional) EMUL.
  OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
      : S(State::Known), EMUL(EMUL), Log2EEW(Log2EEW) {}

  /// Unknown info.
  OperandInfo() : S(State::Unknown) {}

  bool isUnknown() const { return S == State::Unknown; }
  bool isKnown() const { return S == State::Known; }

  /// Return true if A and B agree on EEW and on both components of EMUL.
  /// Both arguments must be Known.
  static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
    assert(A.isKnown() && B.isKnown() && "Both operands must be known");
    return A.Log2EEW == B.Log2EEW && A.EMUL->first == B.EMUL->first &&
           A.EMUL->second == B.EMUL->second;
  }

  /// Print "Unknown", or the EMUL (with an "f" marker when fractional) and
  /// the EEW.
  void print(raw_ostream &OS) const {
    if (isUnknown()) {
      OS << "Unknown";
      return;
    }
    assert(EMUL && "Expected EMUL to have value");
    OS << "EMUL: m";
    if (EMUL->second)
      OS << "f";
    OS << EMUL->first;
    OS << ", EEW: " << (1 << Log2EEW);
  }
};
/// Stream an OperandInfo by forwarding to OperandInfo::print. Marked unused
/// because the only uses in this file are inside LLVM_DEBUG.
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
  OI.print(OS);
  return OS;
}
namespace llvm {
namespace RISCVVType {
/// Return the RISCVII::VLMUL encoding that is twice VLMul.
/// Precondition: VLMul is not LMUL_RESERVED or LMUL_8; any other input hits
/// llvm_unreachable.
static RISCVII::VLMUL twoTimesVLMUL(RISCVII::VLMUL VLMul) {
  using VLMUL = RISCVII::VLMUL;
  switch (VLMul) {
  // Fractional LMULs step towards LMUL_1 ...
  case VLMUL::LMUL_F8:
    return VLMUL::LMUL_F4;
  case VLMUL::LMUL_F4:
    return VLMUL::LMUL_F2;
  case VLMUL::LMUL_F2:
    return VLMUL::LMUL_1;
  // ... and integral LMULs step towards LMUL_8.
  case VLMUL::LMUL_1:
    return VLMUL::LMUL_2;
  case VLMUL::LMUL_2:
    return VLMUL::LMUL_4;
  case VLMUL::LMUL_4:
    return VLMUL::LMUL_8;
  // LMUL_8 cannot be doubled.
  case VLMUL::LMUL_8:
  default:
    llvm_unreachable("Could not multiply VLMul by 2");
  }
}
/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
/// SEW are from the TSFlags of MI.
static std::pair<unsigned, bool>
getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
RISCVII::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL);
unsigned MILog2SEW =
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
unsigned MISEW = 1 << MILog2SEW;
unsigned EEW = 1 << Log2EEW;
// Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
// to put fraction in simplest form.
unsigned Num = EEW, Denom = MISEW;
int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL)
: std::gcd(Num * MILMUL, Denom);
Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
}
} // end namespace RISCVVType
} // end namespace llvm
/// OperandInfo for the operands of an integer extension by Factor.
///
/// The destination (operand 0) has EEW=SEW and EMUL=LMUL. Any other operand
/// has EEW=SEW/Factor (i.e. F2 => SEW/2) and EMUL=(EEW/SEW)*LMUL. LMUL comes
/// from the TSFlags of MI and SEW from its SEW operand.
static OperandInfo getIntegerExtensionOperandInfo(unsigned Factor,
                                                  const MachineInstr &MI,
                                                  const MachineOperand &MO) {
  const MCInstrDesc &Desc = MI.getDesc();
  unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();

  // Destination: described directly by the instruction's LMUL and SEW.
  if (MO.getOperandNo() == 0)
    return OperandInfo(RISCVII::getLMul(Desc.TSFlags), MILog2SEW);

  // Source: narrower by Factor.
  unsigned MISEW = 1 << MILog2SEW;
  unsigned Log2EEW = Log2_32(MISEW / Factor);
  auto EMUL = RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(Log2EEW, MI);
  return OperandInfo(EMUL, Log2EEW);
}
/// Check whether MO is a mask operand of MI.
static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
const MachineRegisterInfo *MRI) {
if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI))
return false;
const MCInstrDesc &Desc = MI.getDesc();
return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID;
}
/// Return the OperandInfo (EMUL and EEW) for MO, which is an operand of MI.
///
/// MI must have an entry in the RVV pseudo table and a SEW operand. The result
/// is Unknown when MO is a passthru operand holding a real register, or when
/// MI's base instruction is not handled by the switch below.
static OperandInfo getOperandInfo(const MachineInstr &MI,
                                  const MachineOperand &MO,
                                  const MachineRegisterInfo *MRI) {
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  // MI has a VLMUL and SEW associated with it. The RVV specification defines
  // the LMUL and SEW of each operand and definition in relation to MI.VLMUL and
  // MI.SEW.
  RISCVII::VLMUL MIVLMul = RISCVII::getLMul(MI.getDesc().TSFlags);
  unsigned MILog2SEW =
      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();

  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc());

  // We bail out early for instructions that have passthru with non NoRegister,
  // which means they are using TU policy. We are not interested in these
  // since they must preserve the entire register content.
  if (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs() &&
      (MO.getReg() != RISCV::NoRegister))
    return {};

  bool IsMODef = MO.getOperandNo() == 0;

  // All mask operands have EEW=1, EMUL=(EEW/SEW)*LMUL
  if (isMaskOperand(MI, MO, MRI))
    return OperandInfo(RISCVVType::getEMULEqualsEEWDivSEWTimesLMUL(0, MI), 0);

  // switch against BaseInstr to reduce number of cases that need to be
  // considered.
  switch (RVV->BaseInstr) {

  // 6. Configuration-Setting Instructions
  // Configuration setting instructions do not read or write vector registers
  case RISCV::VSETIVLI:
  case RISCV::VSETVL:
  case RISCV::VSETVLI:
    llvm_unreachable("Configuration setting instructions do not read or write "
                     "vector registers");

  // Vector Integer Arithmetic Instructions
  // Vector Single-Width Integer Add and Subtract
  case RISCV::VADD_VI:
  case RISCV::VADD_VV:
  case RISCV::VADD_VX:
  case RISCV::VSUB_VV:
  case RISCV::VSUB_VX:
  case RISCV::VRSUB_VI:
  case RISCV::VRSUB_VX:
  // Vector Bitwise Logical Instructions
  // Vector Single-Width Shift Instructions
  // EEW=SEW. EMUL=LMUL.
  case RISCV::VAND_VI:
  case RISCV::VAND_VV:
  case RISCV::VAND_VX:
  case RISCV::VOR_VI:
  case RISCV::VOR_VV:
  case RISCV::VOR_VX:
  case RISCV::VXOR_VI:
  case RISCV::VXOR_VV:
  case RISCV::VXOR_VX:
  case RISCV::VSLL_VI:
  case RISCV::VSLL_VV:
  case RISCV::VSLL_VX:
  case RISCV::VSRL_VI:
  case RISCV::VSRL_VV:
  case RISCV::VSRL_VX:
  case RISCV::VSRA_VI:
  case RISCV::VSRA_VV:
  case RISCV::VSRA_VX:
  // Vector Integer Min/Max Instructions
  // EEW=SEW. EMUL=LMUL.
  case RISCV::VMINU_VV:
  case RISCV::VMINU_VX:
  case RISCV::VMIN_VV:
  case RISCV::VMIN_VX:
  case RISCV::VMAXU_VV:
  case RISCV::VMAXU_VX:
  case RISCV::VMAX_VV:
  case RISCV::VMAX_VX:
  // Vector Single-Width Integer Multiply Instructions
  // Source and Dest EEW=SEW and EMUL=LMUL.
  case RISCV::VMUL_VV:
  case RISCV::VMUL_VX:
  case RISCV::VMULH_VV:
  case RISCV::VMULH_VX:
  case RISCV::VMULHU_VV:
  case RISCV::VMULHU_VX:
  case RISCV::VMULHSU_VV:
  case RISCV::VMULHSU_VX:
  // Vector Integer Divide Instructions
  // EEW=SEW. EMUL=LMUL.
  case RISCV::VDIVU_VV:
  case RISCV::VDIVU_VX:
  case RISCV::VDIV_VV:
  case RISCV::VDIV_VX:
  case RISCV::VREMU_VV:
  case RISCV::VREMU_VX:
  case RISCV::VREM_VV:
  case RISCV::VREM_VX:
  // Vector Single-Width Integer Multiply-Add Instructions
  // EEW=SEW. EMUL=LMUL.
  case RISCV::VMACC_VV:
  case RISCV::VMACC_VX:
  case RISCV::VNMSAC_VV:
  case RISCV::VNMSAC_VX:
  case RISCV::VMADD_VV:
  case RISCV::VMADD_VX:
  case RISCV::VNMSUB_VV:
  case RISCV::VNMSUB_VX:
  // Vector Integer Merge Instructions
  // EEW=SEW and EMUL=LMUL, except the mask operand has EEW=1 and EMUL=
  // (EEW/SEW)*LMUL. Mask operand is handled before this switch.
  case RISCV::VMERGE_VIM:
  case RISCV::VMERGE_VVM:
  case RISCV::VMERGE_VXM:
  // Vector Integer Move Instructions
  // Vector Fixed-Point Arithmetic Instructions
  // Vector Single-Width Saturating Add and Subtract
  // Vector Single-Width Averaging Add and Subtract
  // EEW=SEW. EMUL=LMUL.
  case RISCV::VMV_V_I:
  case RISCV::VMV_V_V:
  case RISCV::VMV_V_X:
  case RISCV::VSADDU_VI:
  case RISCV::VSADDU_VV:
  case RISCV::VSADDU_VX:
  case RISCV::VSADD_VI:
  case RISCV::VSADD_VV:
  case RISCV::VSADD_VX:
  case RISCV::VSSUBU_VV:
  case RISCV::VSSUBU_VX:
  case RISCV::VSSUB_VV:
  case RISCV::VSSUB_VX:
  case RISCV::VAADDU_VV:
  case RISCV::VAADDU_VX:
  case RISCV::VAADD_VV:
  case RISCV::VAADD_VX:
  case RISCV::VASUBU_VV:
  case RISCV::VASUBU_VX:
  case RISCV::VASUB_VV:
  case RISCV::VASUB_VX:
  // Vector Single-Width Scaling Shift Instructions
  // EEW=SEW. EMUL=LMUL.
  case RISCV::VSSRL_VI:
  case RISCV::VSSRL_VV:
  case RISCV::VSSRL_VX:
  case RISCV::VSSRA_VI:
  case RISCV::VSSRA_VV:
  case RISCV::VSSRA_VX:
  // Vector Permutation Instructions
  // Integer Scalar Move Instructions
  // Floating-Point Scalar Move Instructions
  // EMUL=LMUL. EEW=SEW.
  case RISCV::VMV_X_S:
  case RISCV::VMV_S_X:
  case RISCV::VFMV_F_S:
  case RISCV::VFMV_S_F:
  // Vector Slide Instructions
  // EMUL=LMUL. EEW=SEW.
  case RISCV::VSLIDEUP_VI:
  case RISCV::VSLIDEUP_VX:
  case RISCV::VSLIDEDOWN_VI:
  case RISCV::VSLIDEDOWN_VX:
  case RISCV::VSLIDE1UP_VX:
  case RISCV::VFSLIDE1UP_VF:
  case RISCV::VSLIDE1DOWN_VX:
  case RISCV::VFSLIDE1DOWN_VF:
  // Vector Register Gather Instructions
  // EMUL=LMUL. EEW=SEW. For mask operand, EMUL=1 and EEW=1.
  case RISCV::VRGATHER_VI:
  case RISCV::VRGATHER_VV:
  case RISCV::VRGATHER_VX:
  // Vector Compress Instruction
  // EMUL=LMUL. EEW=SEW.
  case RISCV::VCOMPRESS_VM:
    return OperandInfo(MIVLMul, MILog2SEW);

  // Vector Widening Integer Add/Subtract
  // Def uses EEW=2*SEW and EMUL=2*LMUL. Operands use EEW=SEW and EMUL=LMUL.
  case RISCV::VWADDU_VV:
  case RISCV::VWADDU_VX:
  case RISCV::VWSUBU_VV:
  case RISCV::VWSUBU_VX:
  case RISCV::VWADD_VV:
  case RISCV::VWADD_VX:
  case RISCV::VWSUB_VV:
  case RISCV::VWSUB_VX:
  case RISCV::VWSLL_VI:
  // Vector Widening Integer Multiply Instructions
  // Def uses EEW=2*SEW and EMUL=2*LMUL. Sources use EEW=SEW and EMUL=LMUL.
  case RISCV::VWMUL_VV:
  case RISCV::VWMUL_VX:
  case RISCV::VWMULSU_VV:
  case RISCV::VWMULSU_VX:
  case RISCV::VWMULU_VV:
  case RISCV::VWMULU_VX: {
    unsigned Log2EEW = IsMODef ? MILog2SEW + 1 : MILog2SEW;
    RISCVII::VLMUL EMUL =
        IsMODef ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul;
    return OperandInfo(EMUL, Log2EEW);
  }

  // Def and Op1 uses EEW=2*SEW and EMUL=2*LMUL. Op2 uses EEW=SEW and EMUL=LMUL
  case RISCV::VWADDU_WV:
  case RISCV::VWADDU_WX:
  case RISCV::VWSUBU_WV:
  case RISCV::VWSUBU_WX:
  case RISCV::VWADD_WV:
  case RISCV::VWADD_WX:
  case RISCV::VWSUB_WV:
  case RISCV::VWSUB_WX: {
    // With a passthru operand present, the first source is shifted from
    // operand 1 to operand 2.
    // NOTE(review): pseudo variants that tie the wide source itself to the
    // def (rather than a passthru) would not fit this rule - confirm whether
    // such variants can reach here.
    bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
    bool TwoTimes = IsMODef || IsOp1;
    unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW;
    RISCVII::VLMUL EMUL =
        TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul;
    return OperandInfo(EMUL, Log2EEW);
  }

  // Vector Widening Integer Multiply-Add Instructions
  // Destination EEW=2*SEW and EMUL=2*LMUL. Source EEW=SEW and EMUL=LMUL.
  // Even though the add is a 2*SEW addition, the operands of the add are the
  // Dest which is 2*SEW and the result of the multiply which is 2*SEW.
  case RISCV::VWMACCU_VV:
  case RISCV::VWMACCU_VX:
  case RISCV::VWMACC_VV:
  case RISCV::VWMACC_VX:
  case RISCV::VWMACCSU_VV:
  case RISCV::VWMACCSU_VX:
  case RISCV::VWMACCUS_VX: {
    // These are ternary operations: operand 1 is the 2*SEW accumulator that is
    // tied to the destination, and both multiplicands (operands 2 and 3) are
    // read at SEW. They must not share the "HasPassthru ? 2 : 1" rule of the
    // .W add/sub forms, which would report the first multiplicand as a 2*SEW
    // operand.
    bool TwoTimes = IsMODef || MO.getOperandNo() == 1;
    unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW;
    RISCVII::VLMUL EMUL =
        TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul;
    return OperandInfo(EMUL, Log2EEW);
  }

  // Vector Integer Extension
  case RISCV::VZEXT_VF2:
  case RISCV::VSEXT_VF2:
    return getIntegerExtensionOperandInfo(2, MI, MO);
  case RISCV::VZEXT_VF4:
  case RISCV::VSEXT_VF4:
    return getIntegerExtensionOperandInfo(4, MI, MO);
  case RISCV::VZEXT_VF8:
  case RISCV::VSEXT_VF8:
    return getIntegerExtensionOperandInfo(8, MI, MO);

  // Vector Narrowing Integer Right Shift Instructions
  // Destination EEW=SEW and EMUL=LMUL, Op 1 has EEW=2*SEW EMUL=2*LMUL. Op2 has
  // EEW=SEW EMUL=LMUL.
  case RISCV::VNSRL_WX:
  case RISCV::VNSRL_WI:
  case RISCV::VNSRL_WV:
  case RISCV::VNSRA_WI:
  case RISCV::VNSRA_WV:
  case RISCV::VNSRA_WX:
  // Vector Narrowing Fixed-Point Clip Instructions
  // Destination and Op2 EEW=SEW and EMUL=LMUL. Op1 EEW=2*SEW and EMUL=2*LMUL
  case RISCV::VNCLIPU_WI:
  case RISCV::VNCLIPU_WV:
  case RISCV::VNCLIPU_WX:
  case RISCV::VNCLIP_WI:
  case RISCV::VNCLIP_WV:
  case RISCV::VNCLIP_WX: {
    // Only the first source is wide; a passthru operand, when present, shifts
    // it from operand 1 to operand 2.
    bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
    bool TwoTimes = IsOp1;
    unsigned Log2EEW = TwoTimes ? MILog2SEW + 1 : MILog2SEW;
    RISCVII::VLMUL EMUL =
        TwoTimes ? RISCVVType::twoTimesVLMUL(MIVLMul) : MIVLMul;
    return OperandInfo(EMUL, Log2EEW);
  }

  default:
    return {};
  }
}
/// Return true if this optimization should consider MI for VL reduction.
///
/// This is an explicit allow-list keyed on the base instruction of MI's RVV
/// pseudo; it keeps the optimization simple for instructions whose use of VL
/// may have more complex semantics. Instructions without a pseudo-table entry
/// are never supported.
static bool isSupportedInstr(const MachineInstr &MI) {
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  if (!RVV)
    return false;

  switch (RVV->BaseInstr) {
  default:
    return false;

  // Vector Single-Width Integer Add and Subtract
  case RISCV::VADD_VI:
  case RISCV::VADD_VV:
  case RISCV::VADD_VX:
  case RISCV::VSUB_VV:
  case RISCV::VSUB_VX:
  case RISCV::VRSUB_VI:
  case RISCV::VRSUB_VX:
  // Vector Widening Integer Add/Subtract
  case RISCV::VWADDU_VV:
  case RISCV::VWADDU_VX:
  case RISCV::VWSUBU_VV:
  case RISCV::VWSUBU_VX:
  case RISCV::VWADD_VV:
  case RISCV::VWADD_VX:
  case RISCV::VWSUB_VV:
  case RISCV::VWSUB_VX:
  case RISCV::VWADDU_WV:
  case RISCV::VWADDU_WX:
  case RISCV::VWSUBU_WV:
  case RISCV::VWSUBU_WX:
  case RISCV::VWADD_WV:
  case RISCV::VWADD_WX:
  case RISCV::VWSUB_WV:
  case RISCV::VWSUB_WX:
  // Vector Integer Extension
  case RISCV::VZEXT_VF2:
  case RISCV::VSEXT_VF2:
  case RISCV::VZEXT_VF4:
  case RISCV::VSEXT_VF4:
  case RISCV::VZEXT_VF8:
  case RISCV::VSEXT_VF8:
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // FIXME: Add support
  // Vector Bitwise Logical Instructions
  // FIXME: Add support
  // Vector Single-Width Shift Instructions
  // FIXME: Add support (only vsll.vi is handled so far)
  case RISCV::VSLL_VI:
  // Vector Narrowing Integer Right Shift Instructions
  // FIXME: Add support (only vnsrl.wi is handled so far)
  case RISCV::VNSRL_WI:
  // Vector Integer Compare Instructions
  // FIXME: Add support
  // Vector Integer Min/Max Instructions
  case RISCV::VMINU_VV:
  case RISCV::VMINU_VX:
  case RISCV::VMIN_VV:
  case RISCV::VMIN_VX:
  case RISCV::VMAXU_VV:
  case RISCV::VMAXU_VX:
  case RISCV::VMAX_VV:
  case RISCV::VMAX_VX:
  // Vector Single-Width Integer Multiply Instructions
  case RISCV::VMUL_VV:
  case RISCV::VMUL_VX:
  case RISCV::VMULH_VV:
  case RISCV::VMULH_VX:
  case RISCV::VMULHU_VV:
  case RISCV::VMULHU_VX:
  case RISCV::VMULHSU_VV:
  case RISCV::VMULHSU_VX:
  // Vector Integer Divide Instructions
  case RISCV::VDIVU_VV:
  case RISCV::VDIVU_VX:
  case RISCV::VDIV_VV:
  case RISCV::VDIV_VX:
  case RISCV::VREMU_VV:
  case RISCV::VREMU_VX:
  case RISCV::VREM_VV:
  case RISCV::VREM_VX:
  // Vector Widening Integer Multiply Instructions
  case RISCV::VWMUL_VV:
  case RISCV::VWMUL_VX:
  case RISCV::VWMULSU_VV:
  case RISCV::VWMULSU_VX:
  case RISCV::VWMULU_VV:
  case RISCV::VWMULU_VX:
  // Vector Single-Width Integer Multiply-Add Instructions
  // FIXME: Add support
  // Vector Widening Integer Multiply-Add Instructions
  // FIXME: Add support (only the two .vx forms below are handled so far)
  case RISCV::VWMACC_VX:
  case RISCV::VWMACCU_VX:
  // Vector Integer Merge Instructions
  // FIXME: Add support
  // Vector Integer Move Instructions
  // FIXME: Add support (vmv.v.v is not handled)
  case RISCV::VMV_V_I:
  case RISCV::VMV_V_X:
  // Vector Crypto
  case RISCV::VWSLL_VI:
    return true;
  }
}
/// Return true if MO is a vector operand but is used as a scalar operand.
///
/// Currently this recognizes operand index 3 of the reduction instructions
/// listed below. Operands of instructions without a pseudo-table entry are
/// never treated as scalar.
static bool isVectorOpUsedAsScalarOp(MachineOperand &MO) {
  const MachineInstr *ParentMI = MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(ParentMI->getOpcode());
  if (!RVV)
    return false;

  switch (RVV->BaseInstr) {
  default:
    return false;

  // Reductions only use vs1[0] of vs1
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS:
    return MO.getOperandNo() == 3;
  }
}
/// Return true if MI may read elements past VL.
///
/// Instructions without a pseudo-table entry are conservatively assumed to do
/// so.
static bool mayReadPastVL(const MachineInstr &MI) {
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  if (!RVV)
    return true;

  switch (RVV->BaseInstr) {
  default:
    return false;

  // vslidedown instructions may read elements past VL. They are handled
  // according to current tail policy.
  case RISCV::VSLIDEDOWN_VI:
  case RISCV::VSLIDEDOWN_VX:
  case RISCV::VSLIDE1DOWN_VX:
  case RISCV::VFSLIDE1DOWN_VF:
  // vrgather instructions may read the source vector at any index < VLMAX,
  // regardless of VL.
  case RISCV::VRGATHER_VI:
  case RISCV::VRGATHER_VV:
  case RISCV::VRGATHER_VX:
  case RISCV::VRGATHEREI16_VV:
    return true;
  }
}
bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
const MCInstrDesc &Desc = MI.getDesc();
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags))
return false;
if (MI.getNumDefs() != 1)
return false;
// If we're not using VLMAX, then we need to be careful whether we are using
// TA/TU when there is a non-undef Passthru. But when we are using VLMAX, it
// does not matter whether we are using TA/TU with a non-undef Passthru, since
// there are no tail elements to be perserved.
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = MI.getOperand(VLOpNum);
if (VLOp.isReg() || VLOp.getImm() != RISCV::VLMaxSentinel) {
// If MI has a non-undef passthru, we will not try to optimize it since
// that requires us to preserve tail elements according to TA/TU.
// Otherwise, The MI has an undef Passthru, so it doesn't matter whether we
// are using TA/TU.
bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
unsigned PassthruOpIdx = MI.getNumExplicitDefs();
if (HasPassthru &&
MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) {
LLVM_DEBUG(
dbgs() << " Not a candidate because it uses non-undef passthru"
" with non-VLMAX VL\n");
return false;
}
}
// If the VL is 1, then there is no need to reduce it. This is an
// optimization, not needed to preserve correctness.
if (VLOp.isImm() && VLOp.getImm() == 1) {
LLVM_DEBUG(dbgs() << " Not a candidate because VL is already 1\n");
return false;
}
// Some instructions that produce vectors have semantics that make it more
// difficult to determine whether the VL can be reduced. For example, some
// instructions, such as reductions, may write lanes past VL to a scalar
// register. Other instructions, such as some loads or stores, may write
// lower lanes using data from higher lanes. There may be other complex
// semantics not mentioned here that make it hard to determine whether
// the VL can be optimized. As a result, a white-list of supported
// instructions is used. Over time, more instructions cam be supported
// upon careful examination of their semantics under the logic in this
// optimization.
// TODO: Use a better approach than a white-list, such as adding
// properties to instructions using something like TSFlags.
if (!isSupportedInstr(MI)) {
LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
return false;
}
LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
return true;
}
bool RISCVVLOptimizer::checkUsers(const MachineOperand *&CommonVL,
MachineInstr &MI) {
// FIXME: Avoid visiting each user for each time we visit something on the
// worklist, combined with an extra visit from the outer loop. Restructure
// along lines of an instcombine style worklist which integrates the outer
// pass.
bool CanReduceVL = true;
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg())) {
const MachineInstr &UserMI = *UserOp.getParent();
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
// Instructions like reductions may use a vector register as a scalar
// register. In this case, we should treat it like a scalar register which
// does not impact the decision on whether to optimize VL.
if (isVectorOpUsedAsScalarOp(UserOp)) {
[[maybe_unused]] Register R = UserOp.getReg();
[[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R);
assert(RISCV::VRRegClass.hasSubClassEq(RC) &&
"Expect LMUL 1 register class for vector as scalar operands!");
LLVM_DEBUG(dbgs() << " Use this operand as a scalar operand\n");
continue;
}
if (mayReadPastVL(UserMI)) {
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
CanReduceVL = false;
break;
}
// Tied operands might pass through.
if (UserOp.isTied()) {
LLVM_DEBUG(dbgs() << " Abort because user used as tied operand\n");
CanReduceVL = false;
break;
}
const MCInstrDesc &Desc = UserMI.getDesc();
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
LLVM_DEBUG(dbgs() << " Abort due to lack of VL or SEW, assume that"
" use VLMAX\n");
CanReduceVL = false;
break;
}
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
// Looking for an immediate or a register VL that isn't X0.
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
"Did not expect X0 VL");
if (!CommonVL) {
CommonVL = &VLOp;
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
} else if (!CommonVL->isIdenticalTo(VLOp)) {
// FIXME: This check requires all users to have the same VL. We can relax
// this and get the largest VL amongst all users.
LLVM_DEBUG(dbgs() << " Abort because users have different VL\n");
CanReduceVL = false;
break;
}
// The SEW and LMUL of destination and source registers need to match.
// We know that MI DEF is a vector register, because that was the guard
// to call this function.
assert(isVectorRegClass(UserMI.getOperand(0).getReg(), MRI) &&
"Expected DEF and USE to be vector registers");
OperandInfo ConsumerInfo = getOperandInfo(UserMI, UserOp, MRI);
OperandInfo ProducerInfo = getOperandInfo(MI, MI.getOperand(0), MRI);
if (ConsumerInfo.isUnknown() || ProducerInfo.isUnknown() ||
!OperandInfo::EMULAndEEWAreEqual(ConsumerInfo, ProducerInfo)) {
LLVM_DEBUG(dbgs() << " Abort due to incompatible or unknown "
"information for EMUL or EEW.\n");
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
CanReduceVL = false;
break;
}
}
return CanReduceVL;
}
/// Try to reduce the VL of OrigMI to the VL common to its users and, whenever
/// an instruction is rewritten, revisit the candidate instructions defining
/// its vector inputs.
///
/// \returns true if the VL operand of at least one instruction was actually
///          changed.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &OrigMI) {
  SetVector<MachineInstr *> Worklist;
  Worklist.insert(&OrigMI);

  bool MadeChange = false;
  while (!Worklist.empty()) {
    MachineInstr &MI = *Worklist.pop_back_val();
    LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");

    const MachineOperand *CommonVL = nullptr;
    bool CanReduceVL = true;
    if (isVectorRegClass(MI.getOperand(0).getReg(), MRI))
      CanReduceVL = checkUsers(CommonVL, MI);

    if (!CanReduceVL || !CommonVL)
      continue;

    assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
           "Expected VL to be an Imm or virtual Reg");

    unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
    MachineOperand &VLOp = MI.getOperand(VLOpNum);

    // Rewriting VL to the value it already has changes nothing, so it must
    // neither be reported as a modification nor trigger a revisit of the
    // inputs.
    if (CommonVL->isIdenticalTo(VLOp)) {
      LLVM_DEBUG(dbgs() << " Abort due to CommonVL == VLOp, no point in"
                           " reducing.\n");
      continue;
    }

    if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
      LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
      continue;
    }

    if (CommonVL->isImm()) {
      LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                        << CommonVL->getImm() << " for " << MI << "\n");
      VLOp.ChangeToImmediate(CommonVL->getImm());
    } else {
      // The register holding the new VL must be defined at a point that
      // dominates MI, otherwise MI cannot use it.
      const MachineInstr *VLMI = MRI->getVRegDef(CommonVL->getReg());
      if (!MDT->dominates(VLMI, &MI))
        continue;
      LLVM_DEBUG(
          dbgs() << " Reduce VL from " << VLOp << " to "
                 << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
                 << " for " << MI << "\n");

      // All our checks passed. We can reduce VL.
      VLOp.ChangeToRegister(CommonVL->getReg(), false);
    }

    MadeChange = true;

    // Now add all inputs to this instruction to the worklist.
    for (auto &Op : MI.operands()) {
      if (!Op.isReg() || !Op.isUse() || !Op.getReg().isVirtual())
        continue;

      if (!isVectorRegClass(Op.getReg(), MRI))
        continue;

      // A virtual register without a defining instruction has nothing to
      // revisit.
      MachineInstr *DefMI = MRI->getVRegDef(Op.getReg());
      if (!DefMI || !isCandidate(*DefMI))
        continue;

      Worklist.insert(DefMI);
    }
  }

  return MadeChange;
}
/// Pass entry point. Does nothing when the function is skipped or the
/// subtarget has no vector instructions; otherwise runs tryReduceVL on every
/// candidate instruction and returns whether any call reported a change.
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  MRI = &MF.getRegInfo();
  MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();

  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
  if (!ST.hasVInstructions())
    return false;

  bool MadeChange = false;
  for (MachineBasicBlock &MBB : MF) {
    // Walk each block bottom-up.
    for (MachineInstr &MI : make_range(MBB.rbegin(), MBB.rend())) {
      if (isCandidate(MI))
        MadeChange |= tryReduceVL(MI);
    }
  }

  return MadeChange;
}