llvm-project/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
Luke Lau 47b756a5a6
[RISCV] Only reduce VLs of instructions with demanded VLs (#168693)
In RISCVVLOptimizer we first compute all the demanded VLs, then we walk
backwards through the function and try to reduce any VLs.

We don't actually need to walk backwards anymore since after #124530 the
order in which we modify the instructions doesn't matter.

This patch changes it to just iterate over the instructions with a
demanded VL computed, which means we don't iterate over scalar
instructions etc.

This also fixes #168665, where we triggered an assert on instructions
with a dead $vxsat implicit-def:

dead %x:vr = PseudoVSADDU_VV_M1 $noreg, $noreg, $noreg, -1, 3 /* e8 */,
0 /* tu, mu */, implicit-def dead $vxsat

Because $vxsat is a reserved register, DeadMachineInstructionElim won't
remove it and the instruction makes it to RISCVVLOptimizer.

And because the def of %x is dead, we don't reach this instruction in
the dataflow analysis. This instruction returns true for isCandidate, so
we would try to look up its demanded VL, which doesn't exist, and assert.
But with this patch we don't try to reduce instructions that aren't in
DemandedVLs, which fixes the crash.
2025-11-20 03:49:59 +00:00

1731 lines
54 KiB
C++

//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass reduces the VL where possible at the MI level, before VSETVLI
// instructions are inserted.
//
// The purpose of this optimization is to make the VL argument, for instructions
// that have a VL argument, as small as possible.
//
// This is split into a sparse dataflow analysis where we determine what VL is
// demanded by each instruction first, and then afterwards try to reduce the VL
// of each instruction if it demands less than its VL operand.
//
// The analysis is explained in more detail in the 2025 EuroLLVM Developers'
// Meeting talk "Accidental Dataflow Analysis: Extending the RISC-V VL
// Optimizer", which is available on YouTube at
// https://www.youtube.com/watch?v=Mfb5fRSdJAc
//
// The slides for the talk are available at
// https://llvm.org/devmtg/2025-04/slides/technical_talk/lau_accidental_dataflow.pdf
//
//===---------------------------------------------------------------------===//
#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/InitializePasses.h"
using namespace llvm;
#define DEBUG_TYPE "riscv-vl-optimizer"
#define PASS_NAME "RISC-V VL Optimizer"
namespace {
/// Wrapper around MachineOperand that defaults to immediate 0.
struct DemandedVL {
MachineOperand VL;
DemandedVL() : VL(MachineOperand::CreateImm(0)) {}
DemandedVL(MachineOperand VL) : VL(VL) {}
static DemandedVL vlmax() {
return DemandedVL(MachineOperand::CreateImm(RISCV::VLMaxSentinel));
}
bool operator!=(const DemandedVL &Other) const {
return !VL.isIdenticalTo(Other.VL);
}
DemandedVL max(const DemandedVL &X) const {
if (RISCV::isVLKnownLE(VL, X.VL))
return X;
if (RISCV::isVLKnownLE(X.VL, VL))
return *this;
return DemandedVL::vlmax();
}
};
class RISCVVLOptimizer : public MachineFunctionPass {
MachineRegisterInfo *MRI;
const MachineDominatorTree *MDT;
const TargetInstrInfo *TII;
public:
static char ID;
RISCVVLOptimizer() : MachineFunctionPass(ID) {}
bool runOnMachineFunction(MachineFunction &MF) override;
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
AU.addRequired<MachineDominatorTreeWrapperPass>();
MachineFunctionPass::getAnalysisUsage(AU);
}
StringRef getPassName() const override { return PASS_NAME; }
private:
DemandedVL getMinimumVLForUser(const MachineOperand &UserOp) const;
/// Returns true if the users of \p MI have compatible EEWs and SEWs.
bool checkUsers(const MachineInstr &MI) const;
bool tryReduceVL(MachineInstr &MI, MachineOperand VL) const;
bool isCandidate(const MachineInstr &MI) const;
void transfer(const MachineInstr &MI);
/// For a given instruction, records what elements of it are demanded by
/// downstream users.
DenseMap<const MachineInstr *, DemandedVL> DemandedVLs;
SetVector<const MachineInstr *> Worklist;
/// \returns all vector virtual registers that \p MI uses.
auto virtual_vec_uses(const MachineInstr &MI) const {
return make_filter_range(MI.uses(), [this](const MachineOperand &MO) {
return MO.isReg() && MO.getReg().isVirtual() &&
RISCVRegisterInfo::isRVVRegClass(MRI->getRegClass(MO.getReg()));
});
}
};
/// Represents the EMUL and EEW of a MachineOperand.
struct OperandInfo {
  // Represent as 1,2,4,8, ... and fractional indicator. This is because
  // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly.
  // For example, a mask operand can have an EMUL less than MF8.
  // If nullopt, then EMUL isn't used (i.e. only a single scalar is read).
  std::optional<std::pair<unsigned, bool>> EMUL;

  unsigned Log2EEW;

  OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW)
      : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {}

  OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
      : EMUL(EMUL), Log2EEW(Log2EEW) {}

  OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}

  OperandInfo() = delete;

  /// Return true if the EMUL and EEW produced by \p Def is compatible with the
  /// EMUL and EEW used by \p User. A User without an EMUL only reads a single
  /// scalar element, so only the EEW has to agree in that case.
  static bool areCompatible(const OperandInfo &Def, const OperandInfo &User) {
    if (Def.Log2EEW != User.Log2EEW)
      return false;
    if (User.EMUL && Def.EMUL != User.EMUL)
      return false;
    return true;
  }

  /// Print as e.g. "EMUL: mf2, EEW: 16" or "EMUL: none, EEW: 8". Both forms
  /// stay on one line; the caller decides where the line ends.
  void print(raw_ostream &OS) const {
    if (EMUL) {
      OS << "EMUL: m";
      if (EMUL->second)
        OS << "f";
      OS << EMUL->first;
    } else {
      // No newline here: the EEW is appended to the same record below.
      OS << "EMUL: none";
    }
    OS << ", EEW: " << (1 << Log2EEW);
  }
};
} // end anonymous namespace
// Pass identity; its address is what the legacy pass manager uses as the ID.
char RISCVVLOptimizer::ID = 0;

// Register the pass under DEBUG_TYPE ("riscv-vl-optimizer") and declare its
// dependency on the machine dominator tree, matching getAnalysisUsage.
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)

/// Factory used by the RISC-V target to add this pass to the pipeline; the
/// caller takes ownership of the returned pass.
FunctionPass *llvm::createRISCVVLOptimizerPass() {
  return new RISCVVLOptimizer();
}
/// Debug printing for OperandInfo; defers to OperandInfo::print.
[[maybe_unused]]
static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
  OI.print(OS);
  return OS;
}

/// Debug printing for an optional OperandInfo; an empty optional prints as
/// "nullopt".
[[maybe_unused]]
static raw_ostream &operator<<(raw_ostream &OS,
                               const std::optional<OperandInfo> &OI) {
  if (!OI)
    return OS << "nullopt";
  return OS << *OI;
}
/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from \p Log2EEW and LMUL
/// and SEW are taken from the TSFlags and SEW operand of \p MI.
///
/// The result uses the same encoding as RISCVVType::decodeVLMUL:
/// {Value, IsFractional}, so {2, true} is mf2 and {4, false} is m4. EEW, SEW
/// and LMUL are all powers of two, so after reducing the fraction one of the
/// numerator/denominator is 1 and the larger of the two is the magnitude.
static std::pair<unsigned, bool>
getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
  RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
  auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL);
  unsigned MILog2SEW =
      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();

  // Mask instructions will have 0 as the SEW operand. But the LMUL of these
  // instructions is calculated as if the SEW operand was 3 (e8).
  if (MILog2SEW == 0)
    MILog2SEW = 3;

  unsigned MISEW = 1 << MILog2SEW;
  unsigned EEW = 1 << Log2EEW;

  // Calculate (EEW/SEW)*LMUL preserving fractions less than 1, then use the
  // GCD to put the fraction in simplest form. Everything stays unsigned:
  // std::gcd of unsigned operands is unsigned, and narrowing it to int only
  // adds signed/unsigned conversions to the divisions below.
  unsigned Num = MILMULIsFractional ? EEW : EEW * MILMUL;
  unsigned Denom = MILMULIsFractional ? MISEW * MILMUL : MISEW;
  unsigned GCD = std::gcd(Num, Denom);
  Num /= GCD;
  Denom /= GCD;
  return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
}
/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
/// SEW comes from TSFlags of MI.
static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
const MachineInstr &MI,
const MachineOperand &MO) {
unsigned MILog2SEW =
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
if (MO.getOperandNo() == 0)
return MILog2SEW;
unsigned MISEW = 1 << MILog2SEW;
unsigned EEW = MISEW / Factor;
unsigned Log2EEW = Log2_32(EEW);
return Log2EEW;
}
// Expand to the case labels for the 2- to 8-field segment variants of a
// store opcode family with the given EEW. The first label deliberately omits
// `case` and the last omits `:` so that uses read `case VSSEG_CASES(8):`.
#define VSEG_CASES(Prefix, EEW)                                                \
  RISCV::Prefix##SEG2E##EEW##_V:                                               \
  case RISCV::Prefix##SEG3E##EEW##_V:                                          \
  case RISCV::Prefix##SEG4E##EEW##_V:                                          \
  case RISCV::Prefix##SEG5E##EEW##_V:                                          \
  case RISCV::Prefix##SEG6E##EEW##_V:                                          \
  case RISCV::Prefix##SEG7E##EEW##_V:                                          \
  case RISCV::Prefix##SEG8E##EEW##_V
// Unit-stride segment stores (vsseg<nf>e<eew>.v).
#define VSSEG_CASES(EEW) VSEG_CASES(VS, EEW)
// Strided segment stores (vssseg<nf>e<eew>.v).
#define VSSSEG_CASES(EEW) VSEG_CASES(VSS, EEW)
// Unordered indexed segment stores (vsuxseg<nf>ei<eew>.v).
#define VSUXSEG_CASES(EEW) VSEG_CASES(VSUX, I##EEW)
// Ordered indexed segment stores (vsoxseg<nf>ei<eew>.v).
#define VSOXSEG_CASES(EEW) VSEG_CASES(VSOX, I##EEW)
/// Return log2 of the effective element width (EEW) of \p MO, an operand of
/// an RVV pseudo. The EEW is derived from the pseudo's SEW operand following
/// the rules in the RVV specification. Returns std::nullopt when the base
/// instruction is not handled by the switch below.
static std::optional<unsigned> getOperandLog2EEW(const MachineOperand &MO) {
  const MachineInstr &MI = *MO.getParent();
  const MCInstrDesc &Desc = MI.getDesc();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  // MI has a SEW associated with it. The RVV specification defines
  // the EEW of each operand and definition in relation to MI.SEW.
  unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();

  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
  const bool IsTied = RISCVII::isTiedPseudo(Desc.TSFlags);

  // Operand 0 (the destination, or the stored data for stores) and the
  // passthru operand tied to the destination share the destination's EEW.
  bool IsMODef = MO.getOperandNo() == 0 ||
                 (HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs());

  // All mask operands have EEW=1
  const MCOperandInfo &Info = Desc.operands()[MO.getOperandNo()];
  if (Info.OperandType == MCOI::OPERAND_REGISTER &&
      Info.RegClass == RISCV::VMV0RegClassID)
    return 0;

  // switch against BaseInstr to reduce number of cases that need to be
  // considered.
  switch (RVV->BaseInstr) {

  // 6. Configuration-Setting Instructions
  // Configuration setting instructions do not read or write vector registers
  case RISCV::VSETIVLI:
  case RISCV::VSETVL:
  case RISCV::VSETVLI:
    llvm_unreachable("Configuration setting instructions do not read or write "
                     "vector registers");

  // Vector Loads and Stores
  // Vector Unit-Stride Instructions
  // Vector Strided Instructions
  // Dest EEW encoded in the instruction
  case RISCV::VLM_V:
  case RISCV::VSM_V:
    return 0;
  case RISCV::VLE8_V:
  case RISCV::VSE8_V:
  case RISCV::VLSE8_V:
  case RISCV::VSSE8_V:
  case VSSEG_CASES(8):
  case VSSSEG_CASES(8):
    return 3;
  case RISCV::VLE16_V:
  case RISCV::VSE16_V:
  case RISCV::VLSE16_V:
  case RISCV::VSSE16_V:
  case VSSEG_CASES(16):
  case VSSSEG_CASES(16):
    return 4;
  case RISCV::VLE32_V:
  case RISCV::VSE32_V:
  case RISCV::VLSE32_V:
  case RISCV::VSSE32_V:
  case VSSEG_CASES(32):
  case VSSSEG_CASES(32):
    return 5;
  case RISCV::VLE64_V:
  case RISCV::VSE64_V:
  case RISCV::VLSE64_V:
  case RISCV::VSSE64_V:
  case VSSEG_CASES(64):
  case VSSSEG_CASES(64):
    return 6;

  // Vector Indexed Instructions
  // vs(o|u)xei<eew>.v
  // Dest/Data (operand 0) and the passthru of a load have EEW=SEW. The index
  // operand has EEW=<eew>. IsMODef is used rather than checking for operand 0
  // alone so the passthru of an indexed load is not given the index EEW.
  case RISCV::VLUXEI8_V:
  case RISCV::VLOXEI8_V:
  case RISCV::VSUXEI8_V:
  case RISCV::VSOXEI8_V:
  case VSUXSEG_CASES(8):
  case VSOXSEG_CASES(8): {
    if (IsMODef)
      return MILog2SEW;
    return 3;
  }
  case RISCV::VLUXEI16_V:
  case RISCV::VLOXEI16_V:
  case RISCV::VSUXEI16_V:
  case RISCV::VSOXEI16_V:
  case VSUXSEG_CASES(16):
  case VSOXSEG_CASES(16): {
    if (IsMODef)
      return MILog2SEW;
    return 4;
  }
  case RISCV::VLUXEI32_V:
  case RISCV::VLOXEI32_V:
  case RISCV::VSUXEI32_V:
  case RISCV::VSOXEI32_V:
  case VSUXSEG_CASES(32):
  case VSOXSEG_CASES(32): {
    if (IsMODef)
      return MILog2SEW;
    return 5;
  }
  case RISCV::VLUXEI64_V:
  case RISCV::VLOXEI64_V:
  case RISCV::VSUXEI64_V:
  case RISCV::VSOXEI64_V:
  case VSUXSEG_CASES(64):
  case VSOXSEG_CASES(64): {
    if (IsMODef)
      return MILog2SEW;
    return 6;
  }

  // Vector Integer Arithmetic Instructions
  // Vector Single-Width Integer Add and Subtract
  case RISCV::VADD_VI:
  case RISCV::VADD_VV:
  case RISCV::VADD_VX:
  case RISCV::VSUB_VV:
  case RISCV::VSUB_VX:
  case RISCV::VRSUB_VI:
  case RISCV::VRSUB_VX:
  // Vector Bitwise Logical Instructions
  // Vector Single-Width Shift Instructions
  // EEW=SEW.
  case RISCV::VAND_VI:
  case RISCV::VAND_VV:
  case RISCV::VAND_VX:
  case RISCV::VOR_VI:
  case RISCV::VOR_VV:
  case RISCV::VOR_VX:
  case RISCV::VXOR_VI:
  case RISCV::VXOR_VV:
  case RISCV::VXOR_VX:
  case RISCV::VSLL_VI:
  case RISCV::VSLL_VV:
  case RISCV::VSLL_VX:
  case RISCV::VSRL_VI:
  case RISCV::VSRL_VV:
  case RISCV::VSRL_VX:
  case RISCV::VSRA_VI:
  case RISCV::VSRA_VV:
  case RISCV::VSRA_VX:
  // Vector Integer Min/Max Instructions
  // EEW=SEW.
  case RISCV::VMINU_VV:
  case RISCV::VMINU_VX:
  case RISCV::VMIN_VV:
  case RISCV::VMIN_VX:
  case RISCV::VMAXU_VV:
  case RISCV::VMAXU_VX:
  case RISCV::VMAX_VV:
  case RISCV::VMAX_VX:
  // Vector Single-Width Integer Multiply Instructions
  // Source and Dest EEW=SEW.
  case RISCV::VMUL_VV:
  case RISCV::VMUL_VX:
  case RISCV::VMULH_VV:
  case RISCV::VMULH_VX:
  case RISCV::VMULHU_VV:
  case RISCV::VMULHU_VX:
  case RISCV::VMULHSU_VV:
  case RISCV::VMULHSU_VX:
  // Vector Integer Divide Instructions
  // EEW=SEW.
  case RISCV::VDIVU_VV:
  case RISCV::VDIVU_VX:
  case RISCV::VDIV_VV:
  case RISCV::VDIV_VX:
  case RISCV::VREMU_VV:
  case RISCV::VREMU_VX:
  case RISCV::VREM_VV:
  case RISCV::VREM_VX:
  // Vector Single-Width Integer Multiply-Add Instructions
  // EEW=SEW.
  case RISCV::VMACC_VV:
  case RISCV::VMACC_VX:
  case RISCV::VNMSAC_VV:
  case RISCV::VNMSAC_VX:
  case RISCV::VMADD_VV:
  case RISCV::VMADD_VX:
  case RISCV::VNMSUB_VV:
  case RISCV::VNMSUB_VX:
  // Vector Integer Merge Instructions
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
  // before this switch.
  case RISCV::VMERGE_VIM:
  case RISCV::VMERGE_VVM:
  case RISCV::VMERGE_VXM:
  case RISCV::VADC_VIM:
  case RISCV::VADC_VVM:
  case RISCV::VADC_VXM:
  case RISCV::VSBC_VVM:
  case RISCV::VSBC_VXM:
  // Vector Integer Move Instructions
  // Vector Fixed-Point Arithmetic Instructions
  // Vector Single-Width Saturating Add and Subtract
  // Vector Single-Width Averaging Add and Subtract
  // EEW=SEW.
  case RISCV::VMV_V_I:
  case RISCV::VMV_V_V:
  case RISCV::VMV_V_X:
  case RISCV::VSADDU_VI:
  case RISCV::VSADDU_VV:
  case RISCV::VSADDU_VX:
  case RISCV::VSADD_VI:
  case RISCV::VSADD_VV:
  case RISCV::VSADD_VX:
  case RISCV::VSSUBU_VV:
  case RISCV::VSSUBU_VX:
  case RISCV::VSSUB_VV:
  case RISCV::VSSUB_VX:
  case RISCV::VAADDU_VV:
  case RISCV::VAADDU_VX:
  case RISCV::VAADD_VV:
  case RISCV::VAADD_VX:
  case RISCV::VASUBU_VV:
  case RISCV::VASUBU_VX:
  case RISCV::VASUB_VV:
  case RISCV::VASUB_VX:
  // Vector Single-Width Fractional Multiply with Rounding and Saturation
  // EEW=SEW. The instruction produces 2*SEW product internally but
  // saturates to fit into SEW bits.
  case RISCV::VSMUL_VV:
  case RISCV::VSMUL_VX:
  // Vector Single-Width Scaling Shift Instructions
  // EEW=SEW.
  case RISCV::VSSRL_VI:
  case RISCV::VSSRL_VV:
  case RISCV::VSSRL_VX:
  case RISCV::VSSRA_VI:
  case RISCV::VSSRA_VV:
  case RISCV::VSSRA_VX:
  // Vector Permutation Instructions
  // Integer Scalar Move Instructions
  // Floating-Point Scalar Move Instructions
  // EEW=SEW.
  case RISCV::VMV_X_S:
  case RISCV::VMV_S_X:
  case RISCV::VFMV_F_S:
  case RISCV::VFMV_S_F:
  // Vector Slide Instructions
  // EEW=SEW.
  case RISCV::VSLIDEUP_VI:
  case RISCV::VSLIDEUP_VX:
  case RISCV::VSLIDEDOWN_VI:
  case RISCV::VSLIDEDOWN_VX:
  case RISCV::VSLIDE1UP_VX:
  case RISCV::VFSLIDE1UP_VF:
  case RISCV::VSLIDE1DOWN_VX:
  case RISCV::VFSLIDE1DOWN_VF:
  // Vector Register Gather Instructions
  // EEW=SEW. For mask operand, EEW=1.
  case RISCV::VRGATHER_VI:
  case RISCV::VRGATHER_VV:
  case RISCV::VRGATHER_VX:
  // Vector Element Index Instruction
  case RISCV::VID_V:
  // Vector Single-Width Floating-Point Add/Subtract Instructions
  case RISCV::VFADD_VF:
  case RISCV::VFADD_VV:
  case RISCV::VFSUB_VF:
  case RISCV::VFSUB_VV:
  case RISCV::VFRSUB_VF:
  // Vector Single-Width Floating-Point Multiply/Divide Instructions
  case RISCV::VFMUL_VF:
  case RISCV::VFMUL_VV:
  case RISCV::VFDIV_VF:
  case RISCV::VFDIV_VV:
  case RISCV::VFRDIV_VF:
  // Vector Single-Width Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFMACC_VV:
  case RISCV::VFMACC_VF:
  case RISCV::VFNMACC_VV:
  case RISCV::VFNMACC_VF:
  case RISCV::VFMSAC_VV:
  case RISCV::VFMSAC_VF:
  case RISCV::VFNMSAC_VV:
  case RISCV::VFNMSAC_VF:
  case RISCV::VFMADD_VV:
  case RISCV::VFMADD_VF:
  case RISCV::VFNMADD_VV:
  case RISCV::VFNMADD_VF:
  case RISCV::VFMSUB_VV:
  case RISCV::VFMSUB_VF:
  case RISCV::VFNMSUB_VV:
  case RISCV::VFNMSUB_VF:
  // Vector Floating-Point Square-Root Instruction
  case RISCV::VFSQRT_V:
  // Vector Floating-Point Reciprocal Square-Root Estimate Instruction
  case RISCV::VFRSQRT7_V:
  // Vector Floating-Point Reciprocal Estimate Instruction
  case RISCV::VFREC7_V:
  // Vector Floating-Point MIN/MAX Instructions
  case RISCV::VFMIN_VF:
  case RISCV::VFMIN_VV:
  case RISCV::VFMAX_VF:
  case RISCV::VFMAX_VV:
  // Vector Floating-Point Sign-Injection Instructions
  case RISCV::VFSGNJ_VF:
  case RISCV::VFSGNJ_VV:
  case RISCV::VFSGNJN_VV:
  case RISCV::VFSGNJN_VF:
  case RISCV::VFSGNJX_VF:
  case RISCV::VFSGNJX_VV:
  // Vector Floating-Point Classify Instruction
  case RISCV::VFCLASS_V:
  // Vector Floating-Point Move Instruction
  case RISCV::VFMV_V_F:
  // Single-Width Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFCVT_XU_F_V:
  case RISCV::VFCVT_X_F_V:
  case RISCV::VFCVT_RTZ_XU_F_V:
  case RISCV::VFCVT_RTZ_X_F_V:
  case RISCV::VFCVT_F_XU_V:
  case RISCV::VFCVT_F_X_V:
  // Vector Floating-Point Merge Instruction
  case RISCV::VFMERGE_VFM:
  // Vector count population in mask vcpop.m
  // vfirst find-first-set mask bit
  case RISCV::VCPOP_M:
  case RISCV::VFIRST_M:
  // Vector Bit-manipulation Instructions (Zvbb)
  // Vector And-Not
  case RISCV::VANDN_VV:
  case RISCV::VANDN_VX:
  // Vector Reverse Bits in Elements
  case RISCV::VBREV_V:
  // Vector Reverse Bits in Bytes
  case RISCV::VBREV8_V:
  // Vector Reverse Bytes
  case RISCV::VREV8_V:
  // Vector Count Leading Zeros
  case RISCV::VCLZ_V:
  // Vector Count Trailing Zeros
  case RISCV::VCTZ_V:
  // Vector Population Count
  case RISCV::VCPOP_V:
  // Vector Rotate Left
  case RISCV::VROL_VV:
  case RISCV::VROL_VX:
  // Vector Rotate Right
  case RISCV::VROR_VI:
  case RISCV::VROR_VV:
  case RISCV::VROR_VX:
  // Vector Carry-less Multiplication Instructions (Zvbc)
  // Vector Carry-less Multiply
  case RISCV::VCLMUL_VV:
  case RISCV::VCLMUL_VX:
  // Vector Carry-less Multiply Return High Half
  case RISCV::VCLMULH_VV:
  case RISCV::VCLMULH_VX:
    return MILog2SEW;

  // Vector Widening Shift Left Logical (Zvbb)
  case RISCV::VWSLL_VI:
  case RISCV::VWSLL_VX:
  case RISCV::VWSLL_VV:
  // Vector Widening Integer Add/Subtract
  // Def uses EEW=2*SEW . Operands use EEW=SEW.
  case RISCV::VWADDU_VV:
  case RISCV::VWADDU_VX:
  case RISCV::VWSUBU_VV:
  case RISCV::VWSUBU_VX:
  case RISCV::VWADD_VV:
  case RISCV::VWADD_VX:
  case RISCV::VWSUB_VV:
  case RISCV::VWSUB_VX:
  // Vector Widening Integer Multiply Instructions
  // Destination EEW=2*SEW. Source EEW=SEW.
  case RISCV::VWMUL_VV:
  case RISCV::VWMUL_VX:
  case RISCV::VWMULSU_VV:
  case RISCV::VWMULSU_VX:
  case RISCV::VWMULU_VV:
  case RISCV::VWMULU_VX:
  // Vector Widening Integer Multiply-Add Instructions
  // Destination EEW=2*SEW. Source EEW=SEW.
  // A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
  // is then added to the 2*SEW-bit Dest. These instructions never have a
  // passthru operand.
  case RISCV::VWMACCU_VV:
  case RISCV::VWMACCU_VX:
  case RISCV::VWMACC_VV:
  case RISCV::VWMACC_VX:
  case RISCV::VWMACCSU_VV:
  case RISCV::VWMACCSU_VX:
  case RISCV::VWMACCUS_VX:
  // Vector Widening Floating-Point Fused Multiply-Add Instructions
  case RISCV::VFWMACC_VF:
  case RISCV::VFWMACC_VV:
  case RISCV::VFWNMACC_VF:
  case RISCV::VFWNMACC_VV:
  case RISCV::VFWMSAC_VF:
  case RISCV::VFWMSAC_VV:
  case RISCV::VFWNMSAC_VF:
  case RISCV::VFWNMSAC_VV:
  case RISCV::VFWMACCBF16_VV:
  case RISCV::VFWMACCBF16_VF:
  // Vector Widening Floating-Point Add/Subtract Instructions
  // Dest EEW=2*SEW. Source EEW=SEW.
  case RISCV::VFWADD_VV:
  case RISCV::VFWADD_VF:
  case RISCV::VFWSUB_VV:
  case RISCV::VFWSUB_VF:
  // Vector Widening Floating-Point Multiply
  case RISCV::VFWMUL_VF:
  case RISCV::VFWMUL_VV:
  // Widening Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFWCVT_XU_F_V:
  case RISCV::VFWCVT_X_F_V:
  case RISCV::VFWCVT_RTZ_XU_F_V:
  case RISCV::VFWCVT_RTZ_X_F_V:
  case RISCV::VFWCVT_F_XU_V:
  case RISCV::VFWCVT_F_X_V:
  case RISCV::VFWCVT_F_F_V:
  case RISCV::VFWCVTBF16_F_F_V:
    return IsMODef ? MILog2SEW + 1 : MILog2SEW;

  // Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
  case RISCV::VWADDU_WV:
  case RISCV::VWADDU_WX:
  case RISCV::VWSUBU_WV:
  case RISCV::VWSUBU_WX:
  case RISCV::VWADD_WV:
  case RISCV::VWADD_WX:
  case RISCV::VWSUB_WV:
  case RISCV::VWSUB_WX:
  // Vector Widening Floating-Point Add/Subtract Instructions
  case RISCV::VFWADD_WF:
  case RISCV::VFWADD_WV:
  case RISCV::VFWSUB_WF:
  case RISCV::VFWSUB_WV: {
    bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
                                          : MO.getOperandNo() == 1;
    bool TwoTimes = IsMODef || IsOp1;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Integer Extension
  case RISCV::VZEXT_VF2:
  case RISCV::VSEXT_VF2:
    return getIntegerExtensionOperandEEW(2, MI, MO);
  case RISCV::VZEXT_VF4:
  case RISCV::VSEXT_VF4:
    return getIntegerExtensionOperandEEW(4, MI, MO);
  case RISCV::VZEXT_VF8:
  case RISCV::VSEXT_VF8:
    return getIntegerExtensionOperandEEW(8, MI, MO);

  // Vector Narrowing Integer Right Shift Instructions
  // Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
  case RISCV::VNSRL_WX:
  case RISCV::VNSRL_WI:
  case RISCV::VNSRL_WV:
  case RISCV::VNSRA_WI:
  case RISCV::VNSRA_WV:
  case RISCV::VNSRA_WX:
  // Vector Narrowing Fixed-Point Clip Instructions
  // Same layout as the narrowing shifts: Destination EEW=SEW, Op1 (vs2)
  // EEW=2*SEW, Op2 EEW=SEW.
  case RISCV::VNCLIPU_WI:
  case RISCV::VNCLIPU_WV:
  case RISCV::VNCLIPU_WX:
  case RISCV::VNCLIP_WI:
  case RISCV::VNCLIP_WV:
  case RISCV::VNCLIP_WX:
  // Narrowing Floating-Point/Integer Type-Convert Instructions
  case RISCV::VFNCVT_XU_F_W:
  case RISCV::VFNCVT_X_F_W:
  case RISCV::VFNCVT_RTZ_XU_F_W:
  case RISCV::VFNCVT_RTZ_X_F_W:
  case RISCV::VFNCVT_F_XU_W:
  case RISCV::VFNCVT_F_X_W:
  case RISCV::VFNCVT_F_F_W:
  case RISCV::VFNCVT_ROD_F_F_W:
  case RISCV::VFNCVTBF16_F_F_W: {
    assert(!IsTied);
    bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
    bool TwoTimes = IsOp1;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Mask Instructions
  // Vector Mask-Register Logical Instructions
  // vmsbf.m set-before-first mask bit
  // vmsif.m set-including-first mask bit
  // vmsof.m set-only-first mask bit
  // EEW=1
  // We handle the cases when operand is a v0 mask operand above the switch,
  // but these instructions may use non-v0 mask operands and need to be handled
  // specifically.
  case RISCV::VMAND_MM:
  case RISCV::VMNAND_MM:
  case RISCV::VMANDN_MM:
  case RISCV::VMXOR_MM:
  case RISCV::VMOR_MM:
  case RISCV::VMNOR_MM:
  case RISCV::VMORN_MM:
  case RISCV::VMXNOR_MM:
  case RISCV::VMSBF_M:
  case RISCV::VMSIF_M:
  case RISCV::VMSOF_M: {
    return MILog2SEW;
  }

  // Vector Compress Instruction
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
  // before this switch.
  case RISCV::VCOMPRESS_VM:
    return MO.getOperandNo() == 3 ? 0 : MILog2SEW;

  // Vector Iota Instruction
  // EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
  // before this switch.
  case RISCV::VIOTA_M: {
    if (IsMODef || MO.getOperandNo() == 1)
      return MILog2SEW;
    return 0;
  }

  // Vector Integer Compare Instructions
  // Dest EEW=1. Source EEW=SEW.
  case RISCV::VMSEQ_VI:
  case RISCV::VMSEQ_VV:
  case RISCV::VMSEQ_VX:
  case RISCV::VMSNE_VI:
  case RISCV::VMSNE_VV:
  case RISCV::VMSNE_VX:
  case RISCV::VMSLTU_VV:
  case RISCV::VMSLTU_VX:
  case RISCV::VMSLT_VV:
  case RISCV::VMSLT_VX:
  case RISCV::VMSLEU_VV:
  case RISCV::VMSLEU_VI:
  case RISCV::VMSLEU_VX:
  case RISCV::VMSLE_VV:
  case RISCV::VMSLE_VI:
  case RISCV::VMSLE_VX:
  case RISCV::VMSGTU_VI:
  case RISCV::VMSGTU_VX:
  case RISCV::VMSGT_VI:
  case RISCV::VMSGT_VX:
  // Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  // Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
  case RISCV::VMADC_VIM:
  case RISCV::VMADC_VVM:
  case RISCV::VMADC_VXM:
  case RISCV::VMSBC_VVM:
  case RISCV::VMSBC_VXM:
  // Dest EEW=1. Source EEW=SEW.
  case RISCV::VMADC_VV:
  case RISCV::VMADC_VI:
  case RISCV::VMADC_VX:
  case RISCV::VMSBC_VV:
  case RISCV::VMSBC_VX:
  // 13.13. Vector Floating-Point Compare Instructions
  // Dest EEW=1. Source EEW=SEW
  case RISCV::VMFEQ_VF:
  case RISCV::VMFEQ_VV:
  case RISCV::VMFNE_VF:
  case RISCV::VMFNE_VV:
  case RISCV::VMFLT_VF:
  case RISCV::VMFLT_VV:
  case RISCV::VMFLE_VF:
  case RISCV::VMFLE_VV:
  case RISCV::VMFGT_VF:
  case RISCV::VMFGE_VF: {
    if (IsMODef)
      return 0;
    return MILog2SEW;
  }

  // Vector Reduction Operations
  // Vector Single-Width Integer Reduction Instructions
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  // Vector Single-Width Floating-Point Reduction Instructions
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS: {
    return MILog2SEW;
  }

  // Vector Widening Integer Reduction Instructions
  // The Dest and VS1 read only element 0 for the vector register. Return
  // 2*SEW for these. VS2 has EEW=SEW and EMUL=LMUL.
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  // Vector Widening Floating-Point Reduction Instructions
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS: {
    bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
    return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
  }

  // Vector Register Gather with 16-bit Index Elements Instruction
  // Dest and source data EEW=SEW. Index vector EEW=16.
  case RISCV::VRGATHEREI16_VV: {
    if (MO.getOperandNo() == 2)
      return 4;
    return MILog2SEW;
  }

  default:
    return std::nullopt;
  }
}
/// Return the EMUL and EEW of \p MO, or std::nullopt when its EEW is unknown
/// (i.e. getOperandLog2EEW does not handle the instruction). Operands that only
/// read element 0 of a register get an OperandInfo without an EMUL.
static std::optional<OperandInfo> getOperandInfo(const MachineOperand &MO) {
  const MachineInstr &MI = *MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO);
  if (!Log2EEW)
    return std::nullopt;

  switch (RVV->BaseInstr) {
  // Vector Reduction Operations
  // Vector Single-Width Integer Reduction Instructions
  // Vector Single-Width Floating-Point Reduction Instructions
  // Vector Widening Integer Reduction Instructions
  // Vector Widening Floating-Point Reduction Instructions
  // The Dest and VS1 only read element 0 of the vector register. Return just
  // the EEW for these. Operand 2 (VS2) is the full vector source and falls
  // through to the generic EMUL computation below.
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS:
    if (MO.getOperandNo() != 2)
      return OperandInfo(*Log2EEW);
    break;
  }

  // All others have EMUL=EEW/SEW*LMUL
  return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), *Log2EEW);
}
// Forward declaration: isSupportedInstr below needs it, and the definition
// appears further down in this file.
static bool isTupleInsertInstr(const MachineInstr &MI);
/// Return true if this optimization should consider MI for VL reduction. This
/// white-list approach simplifies this optimization for instructions that may
/// have more complex semantics with relation to how it uses VL.
static bool isSupportedInstr(const MachineInstr &MI) {
if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI))
return true;
const RISCVVPseudosTable::PseudoInfo *RVV =
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
if (!RVV)
return false;
switch (RVV->BaseInstr) {
// Vector Unit-Stride Instructions
// Vector Strided Instructions
case RISCV::VLM_V:
case RISCV::VLE8_V:
case RISCV::VLSE8_V:
case RISCV::VLE16_V:
case RISCV::VLSE16_V:
case RISCV::VLE32_V:
case RISCV::VLSE32_V:
case RISCV::VLE64_V:
case RISCV::VLSE64_V:
// Vector Indexed Instructions
case RISCV::VLUXEI8_V:
case RISCV::VLOXEI8_V:
case RISCV::VLUXEI16_V:
case RISCV::VLOXEI16_V:
case RISCV::VLUXEI32_V:
case RISCV::VLOXEI32_V:
case RISCV::VLUXEI64_V:
case RISCV::VLOXEI64_V:
// Vector Single-Width Integer Add and Subtract
case RISCV::VADD_VI:
case RISCV::VADD_VV:
case RISCV::VADD_VX:
case RISCV::VSUB_VV:
case RISCV::VSUB_VX:
case RISCV::VRSUB_VI:
case RISCV::VRSUB_VX:
// Vector Bitwise Logical Instructions
// Vector Single-Width Shift Instructions
case RISCV::VAND_VI:
case RISCV::VAND_VV:
case RISCV::VAND_VX:
case RISCV::VOR_VI:
case RISCV::VOR_VV:
case RISCV::VOR_VX:
case RISCV::VXOR_VI:
case RISCV::VXOR_VV:
case RISCV::VXOR_VX:
case RISCV::VSLL_VI:
case RISCV::VSLL_VV:
case RISCV::VSLL_VX:
case RISCV::VSRL_VI:
case RISCV::VSRL_VV:
case RISCV::VSRL_VX:
case RISCV::VSRA_VI:
case RISCV::VSRA_VV:
case RISCV::VSRA_VX:
// Vector Widening Integer Add/Subtract
case RISCV::VWADDU_VV:
case RISCV::VWADDU_VX:
case RISCV::VWSUBU_VV:
case RISCV::VWSUBU_VX:
case RISCV::VWADD_VV:
case RISCV::VWADD_VX:
case RISCV::VWSUB_VV:
case RISCV::VWSUB_VX:
case RISCV::VWADDU_WV:
case RISCV::VWADDU_WX:
case RISCV::VWSUBU_WV:
case RISCV::VWSUBU_WX:
case RISCV::VWADD_WV:
case RISCV::VWADD_WX:
case RISCV::VWSUB_WV:
case RISCV::VWSUB_WX:
// Vector Integer Extension
case RISCV::VZEXT_VF2:
case RISCV::VSEXT_VF2:
case RISCV::VZEXT_VF4:
case RISCV::VSEXT_VF4:
case RISCV::VZEXT_VF8:
case RISCV::VSEXT_VF8:
// Vector Narrowing Integer Right Shift Instructions
case RISCV::VNSRL_WX:
case RISCV::VNSRL_WI:
case RISCV::VNSRL_WV:
case RISCV::VNSRA_WI:
case RISCV::VNSRA_WV:
case RISCV::VNSRA_WX:
// Vector Integer Compare Instructions
case RISCV::VMSEQ_VI:
case RISCV::VMSEQ_VV:
case RISCV::VMSEQ_VX:
case RISCV::VMSNE_VI:
case RISCV::VMSNE_VV:
case RISCV::VMSNE_VX:
case RISCV::VMSLTU_VV:
case RISCV::VMSLTU_VX:
case RISCV::VMSLT_VV:
case RISCV::VMSLT_VX:
case RISCV::VMSLEU_VV:
case RISCV::VMSLEU_VI:
case RISCV::VMSLEU_VX:
case RISCV::VMSLE_VV:
case RISCV::VMSLE_VI:
case RISCV::VMSLE_VX:
case RISCV::VMSGTU_VI:
case RISCV::VMSGTU_VX:
case RISCV::VMSGT_VI:
case RISCV::VMSGT_VX:
// Vector Integer Min/Max Instructions
case RISCV::VMINU_VV:
case RISCV::VMINU_VX:
case RISCV::VMIN_VV:
case RISCV::VMIN_VX:
case RISCV::VMAXU_VV:
case RISCV::VMAXU_VX:
case RISCV::VMAX_VV:
case RISCV::VMAX_VX:
// Vector Single-Width Integer Multiply Instructions
case RISCV::VMUL_VV:
case RISCV::VMUL_VX:
case RISCV::VMULH_VV:
case RISCV::VMULH_VX:
case RISCV::VMULHU_VV:
case RISCV::VMULHU_VX:
case RISCV::VMULHSU_VV:
case RISCV::VMULHSU_VX:
// Vector Integer Divide Instructions
case RISCV::VDIVU_VV:
case RISCV::VDIVU_VX:
case RISCV::VDIV_VV:
case RISCV::VDIV_VX:
case RISCV::VREMU_VV:
case RISCV::VREMU_VX:
case RISCV::VREM_VV:
case RISCV::VREM_VX:
// Vector Widening Integer Multiply Instructions
case RISCV::VWMUL_VV:
case RISCV::VWMUL_VX:
case RISCV::VWMULSU_VV:
case RISCV::VWMULSU_VX:
case RISCV::VWMULU_VV:
case RISCV::VWMULU_VX:
// Vector Single-Width Integer Multiply-Add Instructions
case RISCV::VMACC_VV:
case RISCV::VMACC_VX:
case RISCV::VNMSAC_VV:
case RISCV::VNMSAC_VX:
case RISCV::VMADD_VV:
case RISCV::VMADD_VX:
case RISCV::VNMSUB_VV:
case RISCV::VNMSUB_VX:
// Vector Integer Merge Instructions
case RISCV::VMERGE_VIM:
case RISCV::VMERGE_VVM:
case RISCV::VMERGE_VXM:
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
case RISCV::VADC_VIM:
case RISCV::VADC_VVM:
case RISCV::VADC_VXM:
case RISCV::VMADC_VIM:
case RISCV::VMADC_VVM:
case RISCV::VMADC_VXM:
case RISCV::VSBC_VVM:
case RISCV::VSBC_VXM:
case RISCV::VMSBC_VVM:
case RISCV::VMSBC_VXM:
case RISCV::VMADC_VV:
case RISCV::VMADC_VI:
case RISCV::VMADC_VX:
case RISCV::VMSBC_VV:
case RISCV::VMSBC_VX:
// Vector Widening Integer Multiply-Add Instructions
case RISCV::VWMACCU_VV:
case RISCV::VWMACCU_VX:
case RISCV::VWMACC_VV:
case RISCV::VWMACC_VX:
case RISCV::VWMACCSU_VV:
case RISCV::VWMACCSU_VX:
case RISCV::VWMACCUS_VX:
// Vector Integer Move Instructions
case RISCV::VMV_V_I:
case RISCV::VMV_V_X:
case RISCV::VMV_V_V:
// Vector Single-Width Saturating Add and Subtract
case RISCV::VSADDU_VV:
case RISCV::VSADDU_VX:
case RISCV::VSADDU_VI:
case RISCV::VSADD_VV:
case RISCV::VSADD_VX:
case RISCV::VSADD_VI:
case RISCV::VSSUBU_VV:
case RISCV::VSSUBU_VX:
case RISCV::VSSUB_VV:
case RISCV::VSSUB_VX:
// Vector Single-Width Averaging Add and Subtract
case RISCV::VAADDU_VV:
case RISCV::VAADDU_VX:
case RISCV::VAADD_VV:
case RISCV::VAADD_VX:
case RISCV::VASUBU_VV:
case RISCV::VASUBU_VX:
case RISCV::VASUB_VV:
case RISCV::VASUB_VX:
// Vector Single-Width Fractional Multiply with Rounding and Saturation
case RISCV::VSMUL_VV:
case RISCV::VSMUL_VX:
// Vector Single-Width Scaling Shift Instructions
case RISCV::VSSRL_VV:
case RISCV::VSSRL_VX:
case RISCV::VSSRL_VI:
case RISCV::VSSRA_VV:
case RISCV::VSSRA_VX:
case RISCV::VSSRA_VI:
// Vector Narrowing Fixed-Point Clip Instructions
case RISCV::VNCLIPU_WV:
case RISCV::VNCLIPU_WX:
case RISCV::VNCLIPU_WI:
case RISCV::VNCLIP_WV:
case RISCV::VNCLIP_WX:
case RISCV::VNCLIP_WI:
// Vector Bit-manipulation Instructions (Zvbb)
// Vector And-Not
case RISCV::VANDN_VV:
case RISCV::VANDN_VX:
// Vector Reverse Bits in Elements
case RISCV::VBREV_V:
// Vector Reverse Bits in Bytes
case RISCV::VBREV8_V:
// Vector Reverse Bytes
case RISCV::VREV8_V:
// Vector Count Leading Zeros
case RISCV::VCLZ_V:
// Vector Count Trailing Zeros
case RISCV::VCTZ_V:
// Vector Population Count
case RISCV::VCPOP_V:
// Vector Rotate Left
case RISCV::VROL_VV:
case RISCV::VROL_VX:
// Vector Rotate Right
case RISCV::VROR_VI:
case RISCV::VROR_VV:
case RISCV::VROR_VX:
// Vector Widening Shift Left Logical
case RISCV::VWSLL_VI:
case RISCV::VWSLL_VX:
case RISCV::VWSLL_VV:
// Vector Carry-less Multiplication Instructions (Zvbc)
// Vector Carry-less Multiply
case RISCV::VCLMUL_VV:
case RISCV::VCLMUL_VX:
// Vector Carry-less Multiply Return High Half
case RISCV::VCLMULH_VV:
case RISCV::VCLMULH_VX:
// Vector Mask Instructions
// Vector Mask-Register Logical Instructions
// vmsbf.m set-before-first mask bit
// vmsif.m set-including-first mask bit
// vmsof.m set-only-first mask bit
// Vector Iota Instruction
// Vector Element Index Instruction
case RISCV::VMAND_MM:
case RISCV::VMNAND_MM:
case RISCV::VMANDN_MM:
case RISCV::VMXOR_MM:
case RISCV::VMOR_MM:
case RISCV::VMNOR_MM:
case RISCV::VMORN_MM:
case RISCV::VMXNOR_MM:
case RISCV::VMSBF_M:
case RISCV::VMSIF_M:
case RISCV::VMSOF_M:
case RISCV::VIOTA_M:
case RISCV::VID_V:
// Vector Slide Instructions
case RISCV::VSLIDEUP_VX:
case RISCV::VSLIDEUP_VI:
case RISCV::VSLIDEDOWN_VX:
case RISCV::VSLIDEDOWN_VI:
case RISCV::VSLIDE1UP_VX:
case RISCV::VFSLIDE1UP_VF:
// Vector Register Gather Instructions
case RISCV::VRGATHER_VI:
case RISCV::VRGATHER_VV:
case RISCV::VRGATHER_VX:
case RISCV::VRGATHEREI16_VV:
// Vector Single-Width Floating-Point Add/Subtract Instructions
case RISCV::VFADD_VF:
case RISCV::VFADD_VV:
case RISCV::VFSUB_VF:
case RISCV::VFSUB_VV:
case RISCV::VFRSUB_VF:
// Vector Widening Floating-Point Add/Subtract Instructions
case RISCV::VFWADD_VV:
case RISCV::VFWADD_VF:
case RISCV::VFWSUB_VV:
case RISCV::VFWSUB_VF:
case RISCV::VFWADD_WF:
case RISCV::VFWADD_WV:
case RISCV::VFWSUB_WF:
case RISCV::VFWSUB_WV:
// Vector Single-Width Floating-Point Multiply/Divide Instructions
case RISCV::VFMUL_VF:
case RISCV::VFMUL_VV:
case RISCV::VFDIV_VF:
case RISCV::VFDIV_VV:
case RISCV::VFRDIV_VF:
// Vector Widening Floating-Point Multiply
case RISCV::VFWMUL_VF:
case RISCV::VFWMUL_VV:
// Vector Single-Width Floating-Point Fused Multiply-Add Instructions
case RISCV::VFMACC_VV:
case RISCV::VFMACC_VF:
case RISCV::VFNMACC_VV:
case RISCV::VFNMACC_VF:
case RISCV::VFMSAC_VV:
case RISCV::VFMSAC_VF:
case RISCV::VFNMSAC_VV:
case RISCV::VFNMSAC_VF:
case RISCV::VFMADD_VV:
case RISCV::VFMADD_VF:
case RISCV::VFNMADD_VV:
case RISCV::VFNMADD_VF:
case RISCV::VFMSUB_VV:
case RISCV::VFMSUB_VF:
case RISCV::VFNMSUB_VV:
case RISCV::VFNMSUB_VF:
// Vector Widening Floating-Point Fused Multiply-Add Instructions
case RISCV::VFWMACC_VV:
case RISCV::VFWMACC_VF:
case RISCV::VFWNMACC_VV:
case RISCV::VFWNMACC_VF:
case RISCV::VFWMSAC_VV:
case RISCV::VFWMSAC_VF:
case RISCV::VFWNMSAC_VV:
case RISCV::VFWNMSAC_VF:
case RISCV::VFWMACCBF16_VV:
case RISCV::VFWMACCBF16_VF:
// Vector Floating-Point Square-Root Instruction
case RISCV::VFSQRT_V:
// Vector Floating-Point Reciprocal Square-Root Estimate Instruction
case RISCV::VFRSQRT7_V:
// Vector Floating-Point Reciprocal Estimate Instruction
case RISCV::VFREC7_V:
// Vector Floating-Point MIN/MAX Instructions
case RISCV::VFMIN_VF:
case RISCV::VFMIN_VV:
case RISCV::VFMAX_VF:
case RISCV::VFMAX_VV:
// Vector Floating-Point Sign-Injection Instructions
case RISCV::VFSGNJ_VF:
case RISCV::VFSGNJ_VV:
case RISCV::VFSGNJN_VV:
case RISCV::VFSGNJN_VF:
case RISCV::VFSGNJX_VF:
case RISCV::VFSGNJX_VV:
// Vector Floating-Point Compare Instructions
case RISCV::VMFEQ_VF:
case RISCV::VMFEQ_VV:
case RISCV::VMFNE_VF:
case RISCV::VMFNE_VV:
case RISCV::VMFLT_VF:
case RISCV::VMFLT_VV:
case RISCV::VMFLE_VF:
case RISCV::VMFLE_VV:
case RISCV::VMFGT_VF:
case RISCV::VMFGE_VF:
// Vector Floating-Point Classify Instruction
case RISCV::VFCLASS_V:
// Vector Floating-Point Merge Instruction
case RISCV::VFMERGE_VFM:
// Vector Floating-Point Move Instruction
case RISCV::VFMV_V_F:
// Single-Width Floating-Point/Integer Type-Convert Instructions
case RISCV::VFCVT_XU_F_V:
case RISCV::VFCVT_X_F_V:
case RISCV::VFCVT_RTZ_XU_F_V:
case RISCV::VFCVT_RTZ_X_F_V:
case RISCV::VFCVT_F_XU_V:
case RISCV::VFCVT_F_X_V:
// Widening Floating-Point/Integer Type-Convert Instructions
case RISCV::VFWCVT_XU_F_V:
case RISCV::VFWCVT_X_F_V:
case RISCV::VFWCVT_RTZ_XU_F_V:
case RISCV::VFWCVT_RTZ_X_F_V:
case RISCV::VFWCVT_F_XU_V:
case RISCV::VFWCVT_F_X_V:
case RISCV::VFWCVT_F_F_V:
case RISCV::VFWCVTBF16_F_F_V:
// Narrowing Floating-Point/Integer Type-Convert Instructions
case RISCV::VFNCVT_XU_F_W:
case RISCV::VFNCVT_X_F_W:
case RISCV::VFNCVT_RTZ_XU_F_W:
case RISCV::VFNCVT_RTZ_X_F_W:
case RISCV::VFNCVT_F_XU_W:
case RISCV::VFNCVT_F_X_W:
case RISCV::VFNCVT_F_F_W:
case RISCV::VFNCVT_ROD_F_F_W:
case RISCV::VFNCVTBF16_F_F_W:
return true;
}
return false;
}
/// Return true if MO is a vector operand but is used as a scalar operand.
static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
const MachineInstr *MI = MO.getParent();
const RISCVVPseudosTable::PseudoInfo *RVV =
RISCVVPseudosTable::getPseudoInfo(MI->getOpcode());
if (!RVV)
return false;
switch (RVV->BaseInstr) {
// Reductions only use vs1[0] of vs1
case RISCV::VREDAND_VS:
case RISCV::VREDMAX_VS:
case RISCV::VREDMAXU_VS:
case RISCV::VREDMIN_VS:
case RISCV::VREDMINU_VS:
case RISCV::VREDOR_VS:
case RISCV::VREDSUM_VS:
case RISCV::VREDXOR_VS:
case RISCV::VWREDSUM_VS:
case RISCV::VWREDSUMU_VS:
case RISCV::VFREDMAX_VS:
case RISCV::VFREDMIN_VS:
case RISCV::VFREDOSUM_VS:
case RISCV::VFREDUSUM_VS:
case RISCV::VFWREDOSUM_VS:
case RISCV::VFWREDUSUM_VS:
return MO.getOperandNo() == 3;
case RISCV::VMV_X_S:
case RISCV::VFMV_F_S:
return MO.getOperandNo() == 1;
default:
return false;
}
}
bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
const MCInstrDesc &Desc = MI.getDesc();
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags))
return false;
if (MI.getNumExplicitDefs() != 1)
return false;
// Some instructions have implicit defs e.g. $vxsat. If they might be read
// later then we can't reduce VL.
if (!MI.allImplicitDefsAreDead()) {
LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
return false;
}
if (MI.mayRaiseFPException()) {
LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
return false;
}
for (const MachineMemOperand *MMO : MI.memoperands()) {
if (MMO->isVolatile()) {
LLVM_DEBUG(dbgs() << "Not a candidate because contains volatile MMO\n");
return false;
}
}
// Some instructions that produce vectors have semantics that make it more
// difficult to determine whether the VL can be reduced. For example, some
// instructions, such as reductions, may write lanes past VL to a scalar
// register. Other instructions, such as some loads or stores, may write
// lower lanes using data from higher lanes. There may be other complex
// semantics not mentioned here that make it hard to determine whether
// the VL can be optimized. As a result, a white-list of supported
// instructions is used. Over time, more instructions can be supported
// upon careful examination of their semantics under the logic in this
// optimization.
// TODO: Use a better approach than a white-list, such as adding
// properties to instructions using something like TSFlags.
if (!isSupportedInstr(MI)) {
LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction: "
<< MI);
return false;
}
assert(!RISCVII::elementsDependOnVL(
TII->get(RISCV::getRVVMCOpcode(MI.getOpcode())).TSFlags) &&
"Instruction shouldn't be supported if elements depend on VL");
assert(RISCVRI::isVRegClass(
MRI->getRegClass(MI.getOperand(0).getReg())->TSFlags) &&
"All supported instructions produce a vector register result");
LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
return true;
}
/// Given a vslidedown.vx like:
///
///   %slideamt = ADDI %x, -1
///   %v = PseudoVSLIDEDOWN_VX %passthru, %src, %slideamt, avl=1
///
/// %v will only read the first %slideamt + 1 lanes of %src, which = %x.
/// This is a common case when lowering extractelement.
///
/// Note that if %x is 0, %slideamt will be all ones. In this case %src will be
/// completely slid down and none of its lanes will be read (since %slideamt is
/// greater than the largest VLMAX of 65536) so we can demand any minimum VL.
///
/// \returns the operand %x as the VL demanded from %src, or std::nullopt if
/// \p UserOp is not the %src operand of a vslidedown.vx of this shape.
///
/// The kind of each operand is checked before getImm()/getReg() is called.
/// An ADDI may carry a symbolic (e.g. %lo) or frame-index operand instead of
/// a plain immediate/register, and those accessors assert on the wrong kind.
static std::optional<DemandedVL>
getMinimumVLForVSLIDEDOWN_VX(const MachineOperand &UserOp,
                             const MachineRegisterInfo *MRI) {
  const MachineInstr &MI = *UserOp.getParent();
  if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VSLIDEDOWN_VX)
    return std::nullopt;
  // We're looking at what lanes are used from the src operand.
  if (UserOp.getOperandNo() != 2)
    return std::nullopt;
  // For now, the AVL must be 1.
  const MachineOperand &AVL = MI.getOperand(4);
  if (!AVL.isImm() || AVL.getImm() != 1)
    return std::nullopt;
  // The slide amount must be %x - 1.
  const MachineOperand &SlideAmt = MI.getOperand(3);
  if (!SlideAmt.isReg() || !SlideAmt.getReg().isVirtual())
    return std::nullopt;
  const MachineInstr *SlideAmtDef = MRI->getUniqueVRegDef(SlideAmt.getReg());
  // Without a unique def the shape cannot be matched.
  if (!SlideAmtDef || SlideAmtDef->getOpcode() != RISCV::ADDI)
    return std::nullopt;
  const MachineOperand &Base = SlideAmtDef->getOperand(1);
  const MachineOperand &Offset = SlideAmtDef->getOperand(2);
  // The offset must be a plain immediate equal to -AVL; symbolic offsets
  // (e.g. %lo(sym)) do not match.
  if (!Offset.isImm() || Offset.getImm() != -AVL.getImm())
    return std::nullopt;
  // %x must be a virtual register; a frame index does not match.
  if (!Base.isReg() || !Base.getReg().isVirtual())
    return std::nullopt;
  return Base;
}
/// Return the VL that the instruction owning \p UserOp needs from the
/// register it reads through \p UserOp.
///
/// PHIs, full COPYs and tuple INSERT_SUBREGs forward their own demanded VL.
/// Users that lack VL/SEW operands, or that read past VL, demand VLMAX. A
/// passthru (tied) use demands VLMAX unless the user's own demanded VL is
/// known to fit within its VL. An operand read as a scalar demands 1.
/// Otherwise the demand is the smaller of the user's VL and its demanded VL,
/// when the latter is known to be <= the former.
DemandedVL
RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
  const MachineInstr &User = *UserOp.getParent();
  const MCInstrDesc &UserDesc = User.getDesc();

  // Forwarding instructions demand whatever their own result demands.
  const bool IsForwarding =
      User.isPHI() || User.isFullCopy() || isTupleInsertInstr(User);
  if (IsForwarding)
    return DemandedVLs.lookup(&User);

  const bool HasVLAndSEW = RISCVII::hasVLOp(UserDesc.TSFlags) &&
                           RISCVII::hasSEWOp(UserDesc.TSFlags);
  if (!HasVLAndSEW) {
    LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
                         " use VLMAX\n");
    return DemandedVL::vlmax();
  }

  if (std::optional<DemandedVL> SlideVL =
          getMinimumVLForVSLIDEDOWN_VX(UserOp, MRI))
    return *SlideVL;

  const uint64_t UserMCFlags =
      TII->get(RISCV::getRVVMCOpcode(User.getOpcode())).TSFlags;
  if (RISCVII::readsPastVL(UserMCFlags)) {
    LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
    return DemandedVL::vlmax();
  }

  const MachineOperand &UserVL =
      User.getOperand(RISCVII::getVLOpNum(UserDesc));
  // Looking for an immediate or a register VL that isn't X0.
  assert((!UserVL.isReg() || UserVL.getReg() != RISCV::X0) &&
         "Did not expect X0 VL");

  // What the user's own result is demanded to be (default if not recorded).
  const DemandedVL UserDemand = DemandedVLs.lookup(&User);
  const bool DemandFitsInUserVL = RISCV::isVLKnownLE(UserDemand.VL, UserVL);

  // A passthru operand supplies the lanes past VL, so it cannot be shrunk if
  // any of those tail lanes are demanded.
  if (UserOp.isTied()) {
    assert(UserOp.getOperandNo() == User.getNumExplicitDefs() &&
           RISCVII::isFirstDefTiedToFirstUse(User.getDesc()));
    if (!DemandFitsInUserVL) {
      LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
                           "instruction with demanded tail\n");
      return DemandedVL::vlmax();
    }
  }

  // Some instructions (e.g. reductions) read a vector register as a scalar,
  // so only the first lane is needed.
  if (isVectorOpUsedAsScalarOp(UserOp)) {
    LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");
    return MachineOperand::CreateImm(1);
  }

  // A known smaller demand on the user lets it ask for fewer lanes.
  if (DemandFitsInUserVL)
    return UserDemand;
  return UserVL;
}
/// Return true if MI is an instruction used for assembling registers
/// for segmented store instructions, namely, RISCVISD::TUPLE_INSERT.
/// Currently it's lowered to INSERT_SUBREG.
static bool isTupleInsertInstr(const MachineInstr &MI) {
if (!MI.isInsertSubreg())
return false;
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
const TargetRegisterClass *DstRC = MRI.getRegClass(MI.getOperand(0).getReg());
const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
if (!RISCVRI::isVRegClass(DstRC->TSFlags))
return false;
unsigned NF = RISCVRI::getNF(DstRC->TSFlags);
if (NF < 2)
return false;
// Check whether INSERT_SUBREG has the correct subreg index for tuple inserts.
auto VLMul = RISCVRI::getLMul(DstRC->TSFlags);
unsigned SubRegIdx = MI.getOperand(3).getImm();
[[maybe_unused]] auto [LMul, IsFractional] = RISCVVType::decodeVLMUL(VLMul);
assert(!IsFractional && "unexpected LMUL for tuple register classes");
return TRI->getSubRegIdxSize(SubRegIdx) == RISCV::RVVBitsPerBlock * LMul;
}
/// Return true if \p MI's base MC opcode is in the vsseg, vssseg, vsuxseg or
/// vsoxseg families, for element widths 8, 16, 32 or 64.
static bool isSegmentedStoreInstr(const MachineInstr &MI) {
  const unsigned BaseOpc = RISCV::getRVVMCOpcode(MI.getOpcode());
  switch (BaseOpc) {
  case VSSEG_CASES(8):
  case VSSEG_CASES(16):
  case VSSEG_CASES(32):
  case VSSEG_CASES(64):
  case VSSSEG_CASES(8):
  case VSSSEG_CASES(16):
  case VSSSEG_CASES(32):
  case VSSSEG_CASES(64):
  case VSUXSEG_CASES(8):
  case VSUXSEG_CASES(16):
  case VSUXSEG_CASES(32):
  case VSUXSEG_CASES(64):
  case VSOXSEG_CASES(8):
  case VSOXSEG_CASES(16):
  case VSOXSEG_CASES(32):
  case VSOXSEG_CASES(64):
    return true;
  }
  return false;
}
bool RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
if (MI.isPHI() || MI.isFullCopy() || isTupleInsertInstr(MI))
return true;
SmallSetVector<MachineOperand *, 8> OpWorklist;
SmallPtrSet<const MachineInstr *, 4> PHISeen;
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg()))
OpWorklist.insert(&UserOp);
while (!OpWorklist.empty()) {
MachineOperand &UserOp = *OpWorklist.pop_back_val();
const MachineInstr &UserMI = *UserOp.getParent();
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
if (UserMI.isFullCopy() && UserMI.getOperand(0).getReg().isVirtual()) {
LLVM_DEBUG(dbgs() << " Peeking through uses of COPY\n");
OpWorklist.insert_range(llvm::make_pointer_range(
MRI->use_operands(UserMI.getOperand(0).getReg())));
continue;
}
if (isTupleInsertInstr(UserMI)) {
LLVM_DEBUG(dbgs().indent(4) << "Peeking through uses of INSERT_SUBREG\n");
for (MachineOperand &UseOp :
MRI->use_operands(UserMI.getOperand(0).getReg())) {
const MachineInstr &CandidateMI = *UseOp.getParent();
// We should not propagate the VL if the user is not a segmented store
// or another INSERT_SUBREG, since VL just works differently
// between segmented operations (per-field) v.s. other RVV ops (on the
// whole register group).
if (!isTupleInsertInstr(CandidateMI) &&
!isSegmentedStoreInstr(CandidateMI))
return false;
OpWorklist.insert(&UseOp);
}
continue;
}
if (UserMI.isPHI()) {
// Don't follow PHI cycles
if (!PHISeen.insert(&UserMI).second)
continue;
LLVM_DEBUG(dbgs() << " Peeking through uses of PHI\n");
OpWorklist.insert_range(llvm::make_pointer_range(
MRI->use_operands(UserMI.getOperand(0).getReg())));
continue;
}
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
return false;
}
std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp);
std::optional<OperandInfo> ProducerInfo = getOperandInfo(MI.getOperand(0));
if (!ConsumerInfo || !ProducerInfo) {
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
return false;
}
if (!OperandInfo::areCompatible(*ProducerInfo, *ConsumerInfo)) {
LLVM_DEBUG(
dbgs()
<< " Abort due to incompatible information for EMUL or EEW.\n");
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
return false;
}
}
return true;
}
/// Try to replace \p MI's VL operand with \p CommonVL.
///
/// The rewrite is skipped in these cases:
///  - MI's VL is already the immediate 1;
///  - \p CommonVL is not known to be <= MI's VL;
///  - \p CommonVL is identical to MI's VL;
///  - \p CommonVL is a register whose def does not dominate MI.
/// If \p CommonVL is produced by a fault-only-first load that does not
/// dominate MI, that load's AVL is used instead.
/// On success with a register VL, the register is constrained to
/// GPRNoX0. Returns true iff MI was modified.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI,
                                   MachineOperand CommonVL) const {
  LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI);

  MachineOperand &CurVL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));

  // Nothing to gain from shrinking a VL of 1 (an optimization, not needed
  // for correctness).
  if (CurVL.isImm() && CurVL.getImm() == 1) {
    LLVM_DEBUG(dbgs() << " Abort due to VL == 1, no point in reducing.\n");
    return false;
  }

  assert((CommonVL.isImm() || CommonVL.getReg().isVirtual()) &&
         "Expected VL to be an Imm or virtual Reg");

  // A vleff's output VL is <= its AVL. If the vleff can't be referenced from
  // MI, fall back to that AVL.
  if (CommonVL.isReg()) {
    const MachineInstr *VLDef = MRI->getVRegDef(CommonVL.getReg());
    const bool IsNonDominatingVLEFF =
        RISCVInstrInfo::isFaultOnlyFirstLoad(*VLDef) &&
        !MDT->dominates(VLDef, &MI);
    if (IsNonDominatingVLEFF)
      CommonVL = VLDef->getOperand(RISCVII::getVLOpNum(VLDef->getDesc()));
  }

  if (!RISCV::isVLKnownLE(CommonVL, CurVL)) {
    LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
    return false;
  }

  if (CommonVL.isIdenticalTo(CurVL)) {
    LLVM_DEBUG(
        dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
    return false;
  }

  if (CommonVL.isImm()) {
    LLVM_DEBUG(dbgs() << " Reduce VL from " << CurVL << " to "
                      << CommonVL.getImm() << " for " << MI << "\n");
    CurVL.ChangeToImmediate(CommonVL.getImm());
    return true;
  }

  // A register VL must be available at MI.
  const auto NewVLReg = CommonVL.getReg();
  if (!MDT->dominates(MRI->getVRegDef(NewVLReg), &MI)) {
    LLVM_DEBUG(dbgs() << " Abort due to VL not dominating.\n");
    return false;
  }

  LLVM_DEBUG(dbgs() << " Reduce VL from " << CurVL << " to "
                    << printReg(NewVLReg, MRI->getTargetRegisterInfo())
                    << " for " << MI << "\n");
  CurVL.ChangeToRegister(NewVLReg, /*isDef=*/false);
  MRI->constrainRegClass(NewVLReg, &RISCV::GPRNoX0RegClass);
  return true;
}
/// Return true if \p MO is a register operand naming a physical register.
static bool isPhysical(const MachineOperand &MO) {
  if (!MO.isReg())
    return false;
  return MO.getReg().isPhysical();
}
/// Dataflow transfer function.
///
/// First, \p MI is forced to demand VLMAX if it is unsupported, fails
/// checkUsers, or defines a physical register. Then, for every virtual vector
/// register \p MI reads, the VL \p MI needs from it is merged (via max) into
/// the defining instruction's demanded VL. Defs whose demand changed are
/// requeued.
void RISCVVLOptimizer::transfer(const MachineInstr &MI) {
  const bool MustDemandVLMAX = !isSupportedInstr(MI) || !checkUsers(MI) ||
                               any_of(MI.defs(), isPhysical);
  if (MustDemandVLMAX)
    DemandedVLs[&MI] = DemandedVL::vlmax();

  for (const MachineOperand &UseOp : virtual_vec_uses(MI)) {
    const MachineInstr *Def = MRI->getVRegDef(UseOp.getReg());
    // lookup() of a missing key yields the same default value operator[]
    // would insert, so computing this first doesn't change the result.
    const DemandedVL NeededByMI = getMinimumVLForUser(UseOp);
    DemandedVL &DefDemand = DemandedVLs[Def];
    const DemandedVL Before = DefDemand;
    DefDemand = DefDemand.max(NeededByMI);
    if (DefDemand != Before)
      Worklist.insert(Def);
  }
}
/// Pass entry point.
///
/// Phase 1 seeds the worklist with every non-debug instruction, visiting
/// blocks in post order and each block bottom-up. It then runs transfer()
/// until DemandedVLs stops changing. Phase 2 tries to shrink the VL of each
/// candidate instruction that has a demanded VL recorded. Returns true if any
/// VL operand was rewritten.
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  MRI = &MF.getRegInfo();
  MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();

  const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
  if (!ST.hasVInstructions())
    return false;
  TII = ST.getInstrInfo();

  assert(DemandedVLs.empty());

  // Phase 1: propagate the VL each instruction needs back to its inputs.
  for (MachineBasicBlock *MBB : post_order(&MF)) {
    assert(MDT->isReachableFromEntry(MBB));
    for (MachineInstr &MI : reverse(*MBB)) {
      if (MI.isDebugInstr())
        continue;
      Worklist.insert(&MI);
    }
  }
  while (!Worklist.empty()) {
    const MachineInstr *Next = Worklist.front();
    Worklist.remove(Next);
    transfer(*Next);
  }

  // Phase 2: shrink VLs to what is demanded. Only instructions with a
  // computed demand are visited; each rewrite is independent of the others.
  bool MadeChange = false;
  for (auto &[MI, VL] : DemandedVLs) {
    assert(MDT->isReachableFromEntry(MI->getParent()));
    if (isCandidate(*MI) &&
        tryReduceVL(*const_cast<MachineInstr *>(MI), VL.VL))
      MadeChange = true;
  }

  DemandedVLs.clear();
  return MadeChange;
}