llvm-project/llvm/lib/Target/RISCV/RISCVVLOptimizer.cpp
Luke Lau e261f2895f
[RISCV] Add TSFlag for reading past VL behaviour. NFCI (#149704)
Currently we have a switch statement that checks if a vector instruction
may read elements past VL. However it currently doesn't account for
instructions in vendor extensions.

Handling all possible vendor instructions will result in quite a lot of
opcodes being added, so I've created a new TSFlag that we can declare in
TableGen, and added it to the existing instruction definitions.

I've tried to be as conservative as possible here: All SiFive vendor vector
instructions should be covered by the flag, as well as all of
XRivosVizip, and ri.vextract from XRivosVisni.

For now this should be NFC because coincidentally, these instructions
aren't handled in getOperandInfo, so RISCVVLOptimizer should currently
avoid touching them despite them being liberally handled in
getMinimumVLForUser.

However in an upcoming patch we'll need to also bail in
getMinimumVLForUser, so this prepares for it.
2025-08-15 01:19:03 +00:00

1546 lines
46 KiB
C++

//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass reduces the VL where possible at the MI level, before VSETVLI
// instructions are inserted.
//
// The purpose of this optimization is to make the VL argument, for instructions
// that have a VL argument, as small as possible. This is implemented by
// visiting each instruction in reverse order and checking that if it has a VL
// argument, whether the VL can be reduced.
//
//===---------------------------------------------------------------------===//
#include "RISCV.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/InitializePasses.h"
using namespace llvm;
#define DEBUG_TYPE "riscv-vl-optimizer"
#define PASS_NAME "RISC-V VL Optimizer"
namespace {
/// Machine function pass that reduces the VL operand of vector pseudo
/// instructions where possible, before VSETVLI instructions are inserted.
class RISCVVLOptimizer : public MachineFunctionPass {
  // Cached per-function state. Initialized to null so that an accidental use
  // before assignment is a deterministic null dereference rather than a read
  // through an indeterminate pointer.
  // NOTE(review): presumably assigned in runOnMachineFunction, whose
  // definition is outside this view.
  const MachineRegisterInfo *MRI = nullptr;
  const MachineDominatorTree *MDT = nullptr;
  const TargetInstrInfo *TII = nullptr;

public:
  static char ID;

  RISCVVLOptimizer() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  /// The pass preserves the CFG and requires the machine dominator tree.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<MachineDominatorTreeWrapperPass>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return PASS_NAME; }

private:
  std::optional<MachineOperand>
  getMinimumVLForUser(const MachineOperand &UserOp) const;

  /// Returns the largest common VL MachineOperand that may be used to optimize
  /// MI. Returns std::nullopt if it failed to find a suitable VL.
  std::optional<MachineOperand> checkUsers(const MachineInstr &MI) const;

  bool tryReduceVL(MachineInstr &MI) const;
  bool isCandidate(const MachineInstr &MI) const;

  /// For a given instruction, records what elements of it are demanded by
  /// downstream users.
  DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs;
};
/// Represents the EMUL and EEW of a MachineOperand.
struct OperandInfo {
  // Represent as 1,2,4,8, ... and fractional indicator. This is because
  // EMUL can take on values that don't map to RISCVVType::VLMUL values exactly.
  // For example, a mask operand can have an EMUL less than MF8.
  // If nullopt, then EMUL isn't used (i.e. only a single scalar is read).
  std::optional<std::pair<unsigned, bool>> EMUL;

  // Log2 of the effective element width of the operand.
  unsigned Log2EEW;

  /// Build from an encoded VLMUL value; it is decoded into the
  /// (multiplier, is-fractional) pair representation.
  OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW)
      : EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {}

  /// Build from an already decoded (multiplier, is-fractional) EMUL pair.
  OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
      : EMUL(EMUL), Log2EEW(Log2EEW) {}

  /// Build an operand description that carries only an EEW; EMUL is left as
  /// nullopt.
  OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}

  OperandInfo() = delete;

  /// Return true if the EMUL and EEW produced by \p Def is compatible with the
  /// EMUL and EEW used by \p User.
  static bool areCompatible(const OperandInfo &Def, const OperandInfo &User) {
    if (Def.Log2EEW != User.Log2EEW)
      return false;
    // A user without an EMUL places no EMUL constraint on the definition.
    if (User.EMUL && Def.EMUL != User.EMUL)
      return false;
    return true;
  }

  /// Print as "EMUL: m[f]<N>, EEW: <bits>" or "EMUL: none, EEW: <bits>".
  void print(raw_ostream &OS) const {
    if (EMUL) {
      OS << "EMUL: m";
      if (EMUL->second)
        OS << "f";
      OS << EMUL->first;
    } else
      // No trailing newline here: the EEW part below continues the same line.
      OS << "EMUL: none";
    OS << ", EEW: " << (1 << Log2EEW);
  }
};
} // end anonymous namespace
// Out-of-class definition of the static pass ID member declared in
// RISCVVLOptimizer; its address is what is handed to MachineFunctionPass in
// the constructor.
char RISCVVLOptimizer::ID = 0;
// Pass initialization boilerplate using DEBUG_TYPE as the pass argument and
// PASS_NAME as the description. The declared dependency mirrors the analysis
// requested via addRequired in getAnalysisUsage.
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
/// Factory entry point: allocates a fresh RISCVVLOptimizer and hands it back
/// through the generic FunctionPass interface.
FunctionPass *llvm::createRISCVVLOptimizerPass() {
  FunctionPass *NewPass = new RISCVVLOptimizer();
  return NewPass;
}
/// Stream an OperandInfo by delegating to OperandInfo::print. Marked unused
/// because it may have no callers in some build configurations.
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
  return OI.print(OS), OS;
}
/// Stream an optional OperandInfo: the literal text "nullopt" when empty,
/// otherwise whatever OperandInfo::print emits for the contained value.
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS,
                               const std::optional<OperandInfo> &OI) {
  if (!OI.has_value())
    return OS << "nullopt";
  OI->print(OS);
  return OS;
}
/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
/// SEW are from the TSFlags of MI.
static std::pair<unsigned, bool>
getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL);
unsigned MILog2SEW =
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
// Mask instructions will have 0 as the SEW operand. But the LMUL of these
// instructions is calculated is as if the SEW operand was 3 (e8).
if (MILog2SEW == 0)
MILog2SEW = 3;
unsigned MISEW = 1 << MILog2SEW;
unsigned EEW = 1 << Log2EEW;
// Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
// to put fraction in simplest form.
unsigned Num = EEW, Denom = MISEW;
int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL)
: std::gcd(Num * MILMUL, Denom);
Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
}
/// Log2 EEW of operand \p MO of an integer extension instruction \p MI.
/// Operand 0 uses SEW; every other operand uses SEW / \p Factor (e.g. a
/// Factor of 2 halves the width). SEW is read from MI's SEW operand.
static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
                                              const MachineInstr &MI,
                                              const MachineOperand &MO) {
  const unsigned Log2SEW =
      MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();

  // Operand 0 is reported at the full SEW.
  const bool IsOperandZero = MO.getOperandNo() == 0;
  if (IsOperandZero)
    return Log2SEW;

  // All remaining operands are narrower by the extension factor.
  const unsigned SEWBits = 1 << Log2SEW;
  const unsigned NarrowBits = SEWBits / Factor;
  return Log2_32(NarrowBits);
}
/// Return the log2 of the effective element width (EEW) of operand \p MO
/// within its parent instruction, or std::nullopt when the parent's base
/// instruction is not one of the cases listed in the switch below.
///
/// The parent instruction must have an entry in the RVV pseudo table (this is
/// asserted). \p MRI is not referenced by this function body.
static std::optional<unsigned>
getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
const MachineInstr &MI = *MO.getParent();
const MCInstrDesc &Desc = MI.getDesc();
const RISCVVPseudosTable::PseudoInfo *RVV =
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
assert(RVV && "Could not find MI in PseudoTable");
// MI has a SEW associated with it. The RVV specification defines
// the EEW of each operand and definition in relation to MI.SEW.
unsigned MILog2SEW = MI.getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc);
const bool IsTied = RISCVII::isTiedPseudo(Desc.TSFlags);
// MO counts as the "def" if it is operand 0, or if the instruction has a
// passthru and MO is the first operand after the explicit defs.
bool IsMODef = MO.getOperandNo() == 0 ||
(HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs());
// All mask operands have EEW=1, i.e. a log2 of 0. They are recognised here by
// their register class.
const MCOperandInfo &Info = Desc.operands()[MO.getOperandNo()];
if (Info.OperandType == MCOI::OPERAND_REGISTER &&
Info.RegClass == RISCV::VMV0RegClassID)
return 0;
// switch against BaseInstr to reduce number of cases that need to be
// considered.
switch (RVV->BaseInstr) {
// 6. Configuration-Setting Instructions
// Configuration setting instructions do not read or write vector registers
case RISCV::VSETIVLI:
case RISCV::VSETVL:
case RISCV::VSETVLI:
llvm_unreachable("Configuration setting instructions do not read or write "
"vector registers");
// Vector Loads and Stores
// Vector Unit-Stride Instructions
// Vector Strided Instructions
// Dest EEW encoded in the instruction
case RISCV::VLM_V:
case RISCV::VSM_V:
return 0;
case RISCV::VLE8_V:
case RISCV::VSE8_V:
case RISCV::VLSE8_V:
case RISCV::VSSE8_V:
return 3;
case RISCV::VLE16_V:
case RISCV::VSE16_V:
case RISCV::VLSE16_V:
case RISCV::VSSE16_V:
return 4;
case RISCV::VLE32_V:
case RISCV::VSE32_V:
case RISCV::VLSE32_V:
case RISCV::VSSE32_V:
return 5;
case RISCV::VLE64_V:
case RISCV::VSE64_V:
case RISCV::VLSE64_V:
case RISCV::VSSE64_V:
return 6;
// Vector Indexed Instructions
// v(l|s)(o|u)xei<eew>.v
// Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>.
case RISCV::VLUXEI8_V:
case RISCV::VLOXEI8_V:
case RISCV::VSUXEI8_V:
case RISCV::VSOXEI8_V: {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 3;
}
case RISCV::VLUXEI16_V:
case RISCV::VLOXEI16_V:
case RISCV::VSUXEI16_V:
case RISCV::VSOXEI16_V: {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 4;
}
case RISCV::VLUXEI32_V:
case RISCV::VLOXEI32_V:
case RISCV::VSUXEI32_V:
case RISCV::VSOXEI32_V: {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 5;
}
case RISCV::VLUXEI64_V:
case RISCV::VLOXEI64_V:
case RISCV::VSUXEI64_V:
case RISCV::VSOXEI64_V: {
if (MO.getOperandNo() == 0)
return MILog2SEW;
return 6;
}
// Vector Integer Arithmetic Instructions
// Vector Single-Width Integer Add and Subtract
case RISCV::VADD_VI:
case RISCV::VADD_VV:
case RISCV::VADD_VX:
case RISCV::VSUB_VV:
case RISCV::VSUB_VX:
case RISCV::VRSUB_VI:
case RISCV::VRSUB_VX:
// Vector Bitwise Logical Instructions
// Vector Single-Width Shift Instructions
// EEW=SEW.
case RISCV::VAND_VI:
case RISCV::VAND_VV:
case RISCV::VAND_VX:
case RISCV::VOR_VI:
case RISCV::VOR_VV:
case RISCV::VOR_VX:
case RISCV::VXOR_VI:
case RISCV::VXOR_VV:
case RISCV::VXOR_VX:
case RISCV::VSLL_VI:
case RISCV::VSLL_VV:
case RISCV::VSLL_VX:
case RISCV::VSRL_VI:
case RISCV::VSRL_VV:
case RISCV::VSRL_VX:
case RISCV::VSRA_VI:
case RISCV::VSRA_VV:
case RISCV::VSRA_VX:
// Vector Integer Min/Max Instructions
// EEW=SEW.
case RISCV::VMINU_VV:
case RISCV::VMINU_VX:
case RISCV::VMIN_VV:
case RISCV::VMIN_VX:
case RISCV::VMAXU_VV:
case RISCV::VMAXU_VX:
case RISCV::VMAX_VV:
case RISCV::VMAX_VX:
// Vector Single-Width Integer Multiply Instructions
// Source and Dest EEW=SEW.
case RISCV::VMUL_VV:
case RISCV::VMUL_VX:
case RISCV::VMULH_VV:
case RISCV::VMULH_VX:
case RISCV::VMULHU_VV:
case RISCV::VMULHU_VX:
case RISCV::VMULHSU_VV:
case RISCV::VMULHSU_VX:
// Vector Integer Divide Instructions
// EEW=SEW.
case RISCV::VDIVU_VV:
case RISCV::VDIVU_VX:
case RISCV::VDIV_VV:
case RISCV::VDIV_VX:
case RISCV::VREMU_VV:
case RISCV::VREMU_VX:
case RISCV::VREM_VV:
case RISCV::VREM_VX:
// Vector Single-Width Integer Multiply-Add Instructions
// EEW=SEW.
case RISCV::VMACC_VV:
case RISCV::VMACC_VX:
case RISCV::VNMSAC_VV:
case RISCV::VNMSAC_VX:
case RISCV::VMADD_VV:
case RISCV::VMADD_VX:
case RISCV::VNMSUB_VV:
case RISCV::VNMSUB_VX:
// Vector Integer Merge Instructions
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
// EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
// before this switch.
case RISCV::VMERGE_VIM:
case RISCV::VMERGE_VVM:
case RISCV::VMERGE_VXM:
case RISCV::VADC_VIM:
case RISCV::VADC_VVM:
case RISCV::VADC_VXM:
case RISCV::VSBC_VVM:
case RISCV::VSBC_VXM:
// Vector Integer Move Instructions
// Vector Fixed-Point Arithmetic Instructions
// Vector Single-Width Saturating Add and Subtract
// Vector Single-Width Averaging Add and Subtract
// EEW=SEW.
case RISCV::VMV_V_I:
case RISCV::VMV_V_V:
case RISCV::VMV_V_X:
case RISCV::VSADDU_VI:
case RISCV::VSADDU_VV:
case RISCV::VSADDU_VX:
case RISCV::VSADD_VI:
case RISCV::VSADD_VV:
case RISCV::VSADD_VX:
case RISCV::VSSUBU_VV:
case RISCV::VSSUBU_VX:
case RISCV::VSSUB_VV:
case RISCV::VSSUB_VX:
case RISCV::VAADDU_VV:
case RISCV::VAADDU_VX:
case RISCV::VAADD_VV:
case RISCV::VAADD_VX:
case RISCV::VASUBU_VV:
case RISCV::VASUBU_VX:
case RISCV::VASUB_VV:
case RISCV::VASUB_VX:
// Vector Single-Width Fractional Multiply with Rounding and Saturation
// EEW=SEW. The instruction produces 2*SEW product internally but
// saturates to fit into SEW bits.
case RISCV::VSMUL_VV:
case RISCV::VSMUL_VX:
// Vector Single-Width Scaling Shift Instructions
// EEW=SEW.
case RISCV::VSSRL_VI:
case RISCV::VSSRL_VV:
case RISCV::VSSRL_VX:
case RISCV::VSSRA_VI:
case RISCV::VSSRA_VV:
case RISCV::VSSRA_VX:
// Vector Permutation Instructions
// Integer Scalar Move Instructions
// Floating-Point Scalar Move Instructions
// EEW=SEW.
case RISCV::VMV_X_S:
case RISCV::VMV_S_X:
case RISCV::VFMV_F_S:
case RISCV::VFMV_S_F:
// Vector Slide Instructions
// EEW=SEW.
case RISCV::VSLIDEUP_VI:
case RISCV::VSLIDEUP_VX:
case RISCV::VSLIDEDOWN_VI:
case RISCV::VSLIDEDOWN_VX:
case RISCV::VSLIDE1UP_VX:
case RISCV::VFSLIDE1UP_VF:
case RISCV::VSLIDE1DOWN_VX:
case RISCV::VFSLIDE1DOWN_VF:
// Vector Register Gather Instructions
// EEW=SEW. For mask operand, EEW=1.
case RISCV::VRGATHER_VI:
case RISCV::VRGATHER_VV:
case RISCV::VRGATHER_VX:
// Vector Compress Instruction
// EEW=SEW.
case RISCV::VCOMPRESS_VM:
// Vector Element Index Instruction
case RISCV::VID_V:
// Vector Single-Width Floating-Point Add/Subtract Instructions
case RISCV::VFADD_VF:
case RISCV::VFADD_VV:
case RISCV::VFSUB_VF:
case RISCV::VFSUB_VV:
case RISCV::VFRSUB_VF:
// Vector Single-Width Floating-Point Multiply/Divide Instructions
case RISCV::VFMUL_VF:
case RISCV::VFMUL_VV:
case RISCV::VFDIV_VF:
case RISCV::VFDIV_VV:
case RISCV::VFRDIV_VF:
// Vector Single-Width Floating-Point Fused Multiply-Add Instructions
case RISCV::VFMACC_VV:
case RISCV::VFMACC_VF:
case RISCV::VFNMACC_VV:
case RISCV::VFNMACC_VF:
case RISCV::VFMSAC_VV:
case RISCV::VFMSAC_VF:
case RISCV::VFNMSAC_VV:
case RISCV::VFNMSAC_VF:
case RISCV::VFMADD_VV:
case RISCV::VFMADD_VF:
case RISCV::VFNMADD_VV:
case RISCV::VFNMADD_VF:
case RISCV::VFMSUB_VV:
case RISCV::VFMSUB_VF:
case RISCV::VFNMSUB_VV:
case RISCV::VFNMSUB_VF:
// Vector Floating-Point Square-Root Instruction
case RISCV::VFSQRT_V:
// Vector Floating-Point Reciprocal Square-Root Estimate Instruction
case RISCV::VFRSQRT7_V:
// Vector Floating-Point Reciprocal Estimate Instruction
case RISCV::VFREC7_V:
// Vector Floating-Point MIN/MAX Instructions
case RISCV::VFMIN_VF:
case RISCV::VFMIN_VV:
case RISCV::VFMAX_VF:
case RISCV::VFMAX_VV:
// Vector Floating-Point Sign-Injection Instructions
case RISCV::VFSGNJ_VF:
case RISCV::VFSGNJ_VV:
case RISCV::VFSGNJN_VV:
case RISCV::VFSGNJN_VF:
case RISCV::VFSGNJX_VF:
case RISCV::VFSGNJX_VV:
// Vector Floating-Point Classify Instruction
case RISCV::VFCLASS_V:
// Vector Floating-Point Move Instruction
case RISCV::VFMV_V_F:
// Single-Width Floating-Point/Integer Type-Convert Instructions
case RISCV::VFCVT_XU_F_V:
case RISCV::VFCVT_X_F_V:
case RISCV::VFCVT_RTZ_XU_F_V:
case RISCV::VFCVT_RTZ_X_F_V:
case RISCV::VFCVT_F_XU_V:
case RISCV::VFCVT_F_X_V:
// Vector Floating-Point Merge Instruction
case RISCV::VFMERGE_VFM:
// Vector count population in mask vcpop.m
// vfirst find-first-set mask bit
case RISCV::VCPOP_M:
case RISCV::VFIRST_M:
// Vector Bit-manipulation Instructions (Zvbb)
// Vector And-Not
case RISCV::VANDN_VV:
case RISCV::VANDN_VX:
// Vector Reverse Bits in Elements
case RISCV::VBREV_V:
// Vector Reverse Bits in Bytes
case RISCV::VBREV8_V:
// Vector Reverse Bytes
case RISCV::VREV8_V:
// Vector Count Leading Zeros
case RISCV::VCLZ_V:
// Vector Count Trailing Zeros
case RISCV::VCTZ_V:
// Vector Population Count
case RISCV::VCPOP_V:
// Vector Rotate Left
case RISCV::VROL_VV:
case RISCV::VROL_VX:
// Vector Rotate Right
case RISCV::VROR_VI:
case RISCV::VROR_VV:
case RISCV::VROR_VX:
// Vector Carry-less Multiplication Instructions (Zvbc)
// Vector Carry-less Multiply
case RISCV::VCLMUL_VV:
case RISCV::VCLMUL_VX:
// Vector Carry-less Multiply Return High Half
case RISCV::VCLMULH_VV:
case RISCV::VCLMULH_VX:
return MILog2SEW;
// Vector Widening Shift Left Logical (Zvbb)
case RISCV::VWSLL_VI:
case RISCV::VWSLL_VX:
case RISCV::VWSLL_VV:
// Vector Widening Integer Add/Subtract
// Def uses EEW=2*SEW. Operands use EEW=SEW.
case RISCV::VWADDU_VV:
case RISCV::VWADDU_VX:
case RISCV::VWSUBU_VV:
case RISCV::VWSUBU_VX:
case RISCV::VWADD_VV:
case RISCV::VWADD_VX:
case RISCV::VWSUB_VV:
case RISCV::VWSUB_VX:
// Vector Widening Integer Multiply Instructions
// Destination EEW=2*SEW. Source EEW=SEW.
case RISCV::VWMUL_VV:
case RISCV::VWMUL_VX:
case RISCV::VWMULSU_VV:
case RISCV::VWMULSU_VX:
case RISCV::VWMULU_VV:
case RISCV::VWMULU_VX:
// Vector Widening Integer Multiply-Add Instructions
// Destination EEW=2*SEW. Source EEW=SEW.
// A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
// is then added to the 2*SEW-bit Dest. These instructions never have a
// passthru operand.
case RISCV::VWMACCU_VV:
case RISCV::VWMACCU_VX:
case RISCV::VWMACC_VV:
case RISCV::VWMACC_VX:
case RISCV::VWMACCSU_VV:
case RISCV::VWMACCSU_VX:
case RISCV::VWMACCUS_VX:
// Vector Widening Floating-Point Fused Multiply-Add Instructions
case RISCV::VFWMACC_VF:
case RISCV::VFWMACC_VV:
case RISCV::VFWNMACC_VF:
case RISCV::VFWNMACC_VV:
case RISCV::VFWMSAC_VF:
case RISCV::VFWMSAC_VV:
case RISCV::VFWNMSAC_VF:
case RISCV::VFWNMSAC_VV:
case RISCV::VFWMACCBF16_VV:
case RISCV::VFWMACCBF16_VF:
// Vector Widening Floating-Point Add/Subtract Instructions
// Dest EEW=2*SEW. Source EEW=SEW.
case RISCV::VFWADD_VV:
case RISCV::VFWADD_VF:
case RISCV::VFWSUB_VV:
case RISCV::VFWSUB_VF:
// Vector Widening Floating-Point Multiply
case RISCV::VFWMUL_VF:
case RISCV::VFWMUL_VV:
// Widening Floating-Point/Integer Type-Convert Instructions
case RISCV::VFWCVT_XU_F_V:
case RISCV::VFWCVT_X_F_V:
case RISCV::VFWCVT_RTZ_XU_F_V:
case RISCV::VFWCVT_RTZ_X_F_V:
case RISCV::VFWCVT_F_XU_V:
case RISCV::VFWCVT_F_X_V:
case RISCV::VFWCVT_F_F_V:
case RISCV::VFWCVTBF16_F_F_V:
return IsMODef ? MILog2SEW + 1 : MILog2SEW;
// Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
case RISCV::VWADDU_WV:
case RISCV::VWADDU_WX:
case RISCV::VWSUBU_WV:
case RISCV::VWSUBU_WX:
case RISCV::VWADD_WV:
case RISCV::VWADD_WX:
case RISCV::VWSUB_WV:
case RISCV::VWSUB_WX:
// Vector Widening Floating-Point Add/Subtract Instructions
case RISCV::VFWADD_WF:
case RISCV::VFWADD_WV:
case RISCV::VFWSUB_WF:
case RISCV::VFWSUB_WV: {
// With a separate (non-tied) passthru, Op1 is shifted to operand index 2;
// otherwise it sits at operand index 1.
bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
: MO.getOperandNo() == 1;
bool TwoTimes = IsMODef || IsOp1;
return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
}
// Vector Integer Extension
case RISCV::VZEXT_VF2:
case RISCV::VSEXT_VF2:
return getIntegerExtensionOperandEEW(2, MI, MO);
case RISCV::VZEXT_VF4:
case RISCV::VSEXT_VF4:
return getIntegerExtensionOperandEEW(4, MI, MO);
case RISCV::VZEXT_VF8:
case RISCV::VSEXT_VF8:
return getIntegerExtensionOperandEEW(8, MI, MO);
// Vector Narrowing Integer Right Shift Instructions
// Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
case RISCV::VNSRL_WX:
case RISCV::VNSRL_WI:
case RISCV::VNSRL_WV:
case RISCV::VNSRA_WI:
case RISCV::VNSRA_WV:
case RISCV::VNSRA_WX:
// Vector Narrowing Fixed-Point Clip Instructions
// Destination and Op2 EEW=SEW. Op1 EEW=2*SEW.
case RISCV::VNCLIPU_WI:
case RISCV::VNCLIPU_WV:
case RISCV::VNCLIPU_WX:
case RISCV::VNCLIP_WI:
case RISCV::VNCLIP_WV:
case RISCV::VNCLIP_WX:
// Narrowing Floating-Point/Integer Type-Convert Instructions
case RISCV::VFNCVT_XU_F_W:
case RISCV::VFNCVT_X_F_W:
case RISCV::VFNCVT_RTZ_XU_F_W:
case RISCV::VFNCVT_RTZ_X_F_W:
case RISCV::VFNCVT_F_XU_W:
case RISCV::VFNCVT_F_X_W:
case RISCV::VFNCVT_F_F_W:
case RISCV::VFNCVT_ROD_F_F_W:
case RISCV::VFNCVTBF16_F_F_W: {
assert(!IsTied);
// Only Op1 (the wide source) is 2*SEW; the destination, passthru and any
// other operand are SEW.
bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
bool TwoTimes = IsOp1;
return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
}
// Vector Mask Instructions
// Vector Mask-Register Logical Instructions
// vmsbf.m set-before-first mask bit
// vmsif.m set-including-first mask bit
// vmsof.m set-only-first mask bit
// EEW=1
// We handle the cases when operand is a v0 mask operand above the switch,
// but these instructions may use non-v0 mask operands and need to be handled
// specifically.
// NOTE(review): this returns the SEW operand rather than a literal 0, so it
// relies on mask instructions carrying 0 as their SEW operand -- confirm.
case RISCV::VMAND_MM:
case RISCV::VMNAND_MM:
case RISCV::VMANDN_MM:
case RISCV::VMXOR_MM:
case RISCV::VMOR_MM:
case RISCV::VMNOR_MM:
case RISCV::VMORN_MM:
case RISCV::VMXNOR_MM:
case RISCV::VMSBF_M:
case RISCV::VMSIF_M:
case RISCV::VMSOF_M: {
return MILog2SEW;
}
// Vector Iota Instruction
// EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
// before this switch.
case RISCV::VIOTA_M: {
if (IsMODef || MO.getOperandNo() == 1)
return MILog2SEW;
return 0;
}
// Vector Integer Compare Instructions
// Dest EEW=1. Source EEW=SEW.
case RISCV::VMSEQ_VI:
case RISCV::VMSEQ_VV:
case RISCV::VMSEQ_VX:
case RISCV::VMSNE_VI:
case RISCV::VMSNE_VV:
case RISCV::VMSNE_VX:
case RISCV::VMSLTU_VV:
case RISCV::VMSLTU_VX:
case RISCV::VMSLT_VV:
case RISCV::VMSLT_VX:
case RISCV::VMSLEU_VV:
case RISCV::VMSLEU_VI:
case RISCV::VMSLEU_VX:
case RISCV::VMSLE_VV:
case RISCV::VMSLE_VI:
case RISCV::VMSLE_VX:
case RISCV::VMSGTU_VI:
case RISCV::VMSGTU_VX:
case RISCV::VMSGT_VI:
case RISCV::VMSGT_VX:
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
// Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
case RISCV::VMADC_VIM:
case RISCV::VMADC_VVM:
case RISCV::VMADC_VXM:
case RISCV::VMSBC_VVM:
case RISCV::VMSBC_VXM:
// Dest EEW=1. Source EEW=SEW.
case RISCV::VMADC_VV:
case RISCV::VMADC_VI:
case RISCV::VMADC_VX:
case RISCV::VMSBC_VV:
case RISCV::VMSBC_VX:
// 13.13. Vector Floating-Point Compare Instructions
// Dest EEW=1. Source EEW=SEW
case RISCV::VMFEQ_VF:
case RISCV::VMFEQ_VV:
case RISCV::VMFNE_VF:
case RISCV::VMFNE_VV:
case RISCV::VMFLT_VF:
case RISCV::VMFLT_VV:
case RISCV::VMFLE_VF:
case RISCV::VMFLE_VV:
case RISCV::VMFGT_VF:
case RISCV::VMFGE_VF: {
if (IsMODef)
return 0;
return MILog2SEW;
}
// Vector Reduction Operations
// Vector Single-Width Integer Reduction Instructions
case RISCV::VREDAND_VS:
case RISCV::VREDMAX_VS:
case RISCV::VREDMAXU_VS:
case RISCV::VREDMIN_VS:
case RISCV::VREDMINU_VS:
case RISCV::VREDOR_VS:
case RISCV::VREDSUM_VS:
case RISCV::VREDXOR_VS:
// Vector Single-Width Floating-Point Reduction Instructions
case RISCV::VFREDMAX_VS:
case RISCV::VFREDMIN_VS:
case RISCV::VFREDOSUM_VS:
case RISCV::VFREDUSUM_VS: {
return MILog2SEW;
}
// Vector Widening Integer Reduction Instructions
// The Dest and VS1 read only element 0 for the vector register. Return
// 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
case RISCV::VWREDSUM_VS:
case RISCV::VWREDSUMU_VS:
// Vector Widening Floating-Point Reduction Instructions
case RISCV::VFWREDOSUM_VS:
case RISCV::VFWREDUSUM_VS: {
// Operand 3 is treated as VS1 here.
bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
}
// Vector Register Gather with 16-bit Index Elements Instruction
// Dest and source data EEW=SEW. Index vector EEW=16.
case RISCV::VRGATHEREI16_VV: {
if (MO.getOperandNo() == 2)
return 4;
return MILog2SEW;
}
// Any base instruction not listed above is unsupported.
default:
return std::nullopt;
}
}
/// Return the EEW and EMUL of operand \p MO within its parent instruction.
///
/// Returns std::nullopt when getOperandLog2EEW cannot determine the EEW of
/// the operand. For reduction instructions, every operand other than operand
/// 2 gets an OperandInfo without an EMUL. All remaining operands get
/// EMUL = (EEW / SEW) * LMUL.
///
/// The parent instruction must have an entry in the RVV pseudo table (this is
/// asserted).
static std::optional<OperandInfo>
getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
  const MachineInstr &MI = *MO.getParent();
  const RISCVVPseudosTable::PseudoInfo *RVV =
      RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
  assert(RVV && "Could not find MI in PseudoTable");

  std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
  if (!Log2EEW)
    return std::nullopt;

  switch (RVV->BaseInstr) {
  // Vector Reduction Operations
  // Vector Single-Width Integer Reduction Instructions
  // Vector Widening Integer Reduction Instructions
  // Vector Single-Width Floating-Point Reduction Instructions
  // Vector Widening Floating-Point Reduction Instructions
  // The Dest and VS1 only read element 0 of the vector register. Return just
  // the EEW for these.
  case RISCV::VREDAND_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDSUM_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  // The single-width floating-point reductions are grouped with the other
  // reductions in getOperandLog2EEW, so they are given the same scalar
  // treatment here.
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VFWREDOSUM_VS:
  case RISCV::VFWREDUSUM_VS:
    // Operand 2 is the only operand that keeps a full EMUL; it falls through
    // to the generic computation below.
    if (MO.getOperandNo() != 2)
      return OperandInfo(*Log2EEW);
    break;
  }

  // All others have EMUL=EEW/SEW*LMUL
  return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), *Log2EEW);
}
/// Return true if this optimization should consider MI for VL reduction. This
/// white-list approach simplifies this optimization for instructions that may
/// have more complex semantics with relation to how it uses VL.
static bool isSupportedInstr(const MachineInstr &MI) {
const RISCVVPseudosTable::PseudoInfo *RVV =
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
if (!RVV)
return false;
switch (RVV->BaseInstr) {
// Vector Unit-Stride Instructions
// Vector Strided Instructions
case RISCV::VLM_V:
case RISCV::VLE8_V:
case RISCV::VLSE8_V:
case RISCV::VLE16_V:
case RISCV::VLSE16_V:
case RISCV::VLE32_V:
case RISCV::VLSE32_V:
case RISCV::VLE64_V:
case RISCV::VLSE64_V:
// Vector Indexed Instructions
case RISCV::VLUXEI8_V:
case RISCV::VLOXEI8_V:
case RISCV::VLUXEI16_V:
case RISCV::VLOXEI16_V:
case RISCV::VLUXEI32_V:
case RISCV::VLOXEI32_V:
case RISCV::VLUXEI64_V:
case RISCV::VLOXEI64_V: {
for (const MachineMemOperand *MMO : MI.memoperands())
if (MMO->isVolatile())
return false;
return true;
}
// Vector Single-Width Integer Add and Subtract
case RISCV::VADD_VI:
case RISCV::VADD_VV:
case RISCV::VADD_VX:
case RISCV::VSUB_VV:
case RISCV::VSUB_VX:
case RISCV::VRSUB_VI:
case RISCV::VRSUB_VX:
// Vector Bitwise Logical Instructions
// Vector Single-Width Shift Instructions
case RISCV::VAND_VI:
case RISCV::VAND_VV:
case RISCV::VAND_VX:
case RISCV::VOR_VI:
case RISCV::VOR_VV:
case RISCV::VOR_VX:
case RISCV::VXOR_VI:
case RISCV::VXOR_VV:
case RISCV::VXOR_VX:
case RISCV::VSLL_VI:
case RISCV::VSLL_VV:
case RISCV::VSLL_VX:
case RISCV::VSRL_VI:
case RISCV::VSRL_VV:
case RISCV::VSRL_VX:
case RISCV::VSRA_VI:
case RISCV::VSRA_VV:
case RISCV::VSRA_VX:
// Vector Widening Integer Add/Subtract
case RISCV::VWADDU_VV:
case RISCV::VWADDU_VX:
case RISCV::VWSUBU_VV:
case RISCV::VWSUBU_VX:
case RISCV::VWADD_VV:
case RISCV::VWADD_VX:
case RISCV::VWSUB_VV:
case RISCV::VWSUB_VX:
case RISCV::VWADDU_WV:
case RISCV::VWADDU_WX:
case RISCV::VWSUBU_WV:
case RISCV::VWSUBU_WX:
case RISCV::VWADD_WV:
case RISCV::VWADD_WX:
case RISCV::VWSUB_WV:
case RISCV::VWSUB_WX:
// Vector Integer Extension
case RISCV::VZEXT_VF2:
case RISCV::VSEXT_VF2:
case RISCV::VZEXT_VF4:
case RISCV::VSEXT_VF4:
case RISCV::VZEXT_VF8:
case RISCV::VSEXT_VF8:
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
// FIXME: Add support
case RISCV::VMADC_VV:
case RISCV::VMADC_VI:
case RISCV::VMADC_VX:
case RISCV::VMSBC_VV:
case RISCV::VMSBC_VX:
// Vector Narrowing Integer Right Shift Instructions
case RISCV::VNSRL_WX:
case RISCV::VNSRL_WI:
case RISCV::VNSRL_WV:
case RISCV::VNSRA_WI:
case RISCV::VNSRA_WV:
case RISCV::VNSRA_WX:
// Vector Integer Compare Instructions
case RISCV::VMSEQ_VI:
case RISCV::VMSEQ_VV:
case RISCV::VMSEQ_VX:
case RISCV::VMSNE_VI:
case RISCV::VMSNE_VV:
case RISCV::VMSNE_VX:
case RISCV::VMSLTU_VV:
case RISCV::VMSLTU_VX:
case RISCV::VMSLT_VV:
case RISCV::VMSLT_VX:
case RISCV::VMSLEU_VV:
case RISCV::VMSLEU_VI:
case RISCV::VMSLEU_VX:
case RISCV::VMSLE_VV:
case RISCV::VMSLE_VI:
case RISCV::VMSLE_VX:
case RISCV::VMSGTU_VI:
case RISCV::VMSGTU_VX:
case RISCV::VMSGT_VI:
case RISCV::VMSGT_VX:
// Vector Integer Min/Max Instructions
case RISCV::VMINU_VV:
case RISCV::VMINU_VX:
case RISCV::VMIN_VV:
case RISCV::VMIN_VX:
case RISCV::VMAXU_VV:
case RISCV::VMAXU_VX:
case RISCV::VMAX_VV:
case RISCV::VMAX_VX:
// Vector Single-Width Integer Multiply Instructions
case RISCV::VMUL_VV:
case RISCV::VMUL_VX:
case RISCV::VMULH_VV:
case RISCV::VMULH_VX:
case RISCV::VMULHU_VV:
case RISCV::VMULHU_VX:
case RISCV::VMULHSU_VV:
case RISCV::VMULHSU_VX:
// Vector Integer Divide Instructions
case RISCV::VDIVU_VV:
case RISCV::VDIVU_VX:
case RISCV::VDIV_VV:
case RISCV::VDIV_VX:
case RISCV::VREMU_VV:
case RISCV::VREMU_VX:
case RISCV::VREM_VV:
case RISCV::VREM_VX:
// Vector Widening Integer Multiply Instructions
case RISCV::VWMUL_VV:
case RISCV::VWMUL_VX:
case RISCV::VWMULSU_VV:
case RISCV::VWMULSU_VX:
case RISCV::VWMULU_VV:
case RISCV::VWMULU_VX:
// Vector Single-Width Integer Multiply-Add Instructions
case RISCV::VMACC_VV:
case RISCV::VMACC_VX:
case RISCV::VNMSAC_VV:
case RISCV::VNMSAC_VX:
case RISCV::VMADD_VV:
case RISCV::VMADD_VX:
case RISCV::VNMSUB_VV:
case RISCV::VNMSUB_VX:
// Vector Integer Merge Instructions
case RISCV::VMERGE_VIM:
case RISCV::VMERGE_VVM:
case RISCV::VMERGE_VXM:
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
case RISCV::VADC_VIM:
case RISCV::VADC_VVM:
case RISCV::VADC_VXM:
case RISCV::VMADC_VIM:
case RISCV::VMADC_VVM:
case RISCV::VMADC_VXM:
case RISCV::VSBC_VVM:
case RISCV::VSBC_VXM:
case RISCV::VMSBC_VVM:
case RISCV::VMSBC_VXM:
// Vector Widening Integer Multiply-Add Instructions
case RISCV::VWMACCU_VV:
case RISCV::VWMACCU_VX:
case RISCV::VWMACC_VV:
case RISCV::VWMACC_VX:
case RISCV::VWMACCSU_VV:
case RISCV::VWMACCSU_VX:
case RISCV::VWMACCUS_VX:
// Vector Integer Merge Instructions
// FIXME: Add support
// Vector Integer Move Instructions
// FIXME: Add support
case RISCV::VMV_V_I:
case RISCV::VMV_V_X:
case RISCV::VMV_V_V:
// Vector Single-Width Saturating Add and Subtract
case RISCV::VSADDU_VV:
case RISCV::VSADDU_VX:
case RISCV::VSADDU_VI:
case RISCV::VSADD_VV:
case RISCV::VSADD_VX:
case RISCV::VSADD_VI:
case RISCV::VSSUBU_VV:
case RISCV::VSSUBU_VX:
case RISCV::VSSUB_VV:
case RISCV::VSSUB_VX:
// Vector Single-Width Averaging Add and Subtract
case RISCV::VAADDU_VV:
case RISCV::VAADDU_VX:
case RISCV::VAADD_VV:
case RISCV::VAADD_VX:
case RISCV::VASUBU_VV:
case RISCV::VASUBU_VX:
case RISCV::VASUB_VV:
case RISCV::VASUB_VX:
// Vector Single-Width Fractional Multiply with Rounding and Saturation
case RISCV::VSMUL_VV:
case RISCV::VSMUL_VX:
// Vector Single-Width Scaling Shift Instructions
case RISCV::VSSRL_VV:
case RISCV::VSSRL_VX:
case RISCV::VSSRL_VI:
case RISCV::VSSRA_VV:
case RISCV::VSSRA_VX:
case RISCV::VSSRA_VI:
// Vector Narrowing Fixed-Point Clip Instructions
case RISCV::VNCLIPU_WV:
case RISCV::VNCLIPU_WX:
case RISCV::VNCLIPU_WI:
case RISCV::VNCLIP_WV:
case RISCV::VNCLIP_WX:
case RISCV::VNCLIP_WI:
// Vector Bit-manipulation Instructions (Zvbb)
// Vector And-Not
case RISCV::VANDN_VV:
case RISCV::VANDN_VX:
// Vector Reverse Bits in Elements
case RISCV::VBREV_V:
// Vector Reverse Bits in Bytes
case RISCV::VBREV8_V:
// Vector Reverse Bytes
case RISCV::VREV8_V:
// Vector Count Leading Zeros
case RISCV::VCLZ_V:
// Vector Count Trailing Zeros
case RISCV::VCTZ_V:
// Vector Population Count
case RISCV::VCPOP_V:
// Vector Rotate Left
case RISCV::VROL_VV:
case RISCV::VROL_VX:
// Vector Rotate Right
case RISCV::VROR_VI:
case RISCV::VROR_VV:
case RISCV::VROR_VX:
// Vector Widening Shift Left Logical
case RISCV::VWSLL_VI:
case RISCV::VWSLL_VX:
case RISCV::VWSLL_VV:
// Vector Carry-less Multiplication Instructions (Zvbc)
// Vector Carry-less Multiply
case RISCV::VCLMUL_VV:
case RISCV::VCLMUL_VX:
// Vector Carry-less Multiply Return High Half
case RISCV::VCLMULH_VV:
case RISCV::VCLMULH_VX:
// Vector Mask Instructions
// Vector Mask-Register Logical Instructions
// vmsbf.m set-before-first mask bit
// vmsif.m set-including-first mask bit
// vmsof.m set-only-first mask bit
// Vector Iota Instruction
// Vector Element Index Instruction
case RISCV::VMAND_MM:
case RISCV::VMNAND_MM:
case RISCV::VMANDN_MM:
case RISCV::VMXOR_MM:
case RISCV::VMOR_MM:
case RISCV::VMNOR_MM:
case RISCV::VMORN_MM:
case RISCV::VMXNOR_MM:
case RISCV::VMSBF_M:
case RISCV::VMSIF_M:
case RISCV::VMSOF_M:
case RISCV::VIOTA_M:
case RISCV::VID_V:
// Vector Slide Instructions
case RISCV::VSLIDEUP_VX:
case RISCV::VSLIDEUP_VI:
case RISCV::VSLIDEDOWN_VX:
case RISCV::VSLIDEDOWN_VI:
case RISCV::VSLIDE1UP_VX:
case RISCV::VFSLIDE1UP_VF:
// Vector Register Gather Instructions
case RISCV::VRGATHER_VI:
case RISCV::VRGATHER_VV:
case RISCV::VRGATHER_VX:
case RISCV::VRGATHEREI16_VV:
// Vector Single-Width Floating-Point Add/Subtract Instructions
case RISCV::VFADD_VF:
case RISCV::VFADD_VV:
case RISCV::VFSUB_VF:
case RISCV::VFSUB_VV:
case RISCV::VFRSUB_VF:
// Vector Widening Floating-Point Add/Subtract Instructions
case RISCV::VFWADD_VV:
case RISCV::VFWADD_VF:
case RISCV::VFWSUB_VV:
case RISCV::VFWSUB_VF:
case RISCV::VFWADD_WF:
case RISCV::VFWADD_WV:
case RISCV::VFWSUB_WF:
case RISCV::VFWSUB_WV:
// Vector Single-Width Floating-Point Multiply/Divide Instructions
case RISCV::VFMUL_VF:
case RISCV::VFMUL_VV:
case RISCV::VFDIV_VF:
case RISCV::VFDIV_VV:
case RISCV::VFRDIV_VF:
// Vector Widening Floating-Point Multiply
case RISCV::VFWMUL_VF:
case RISCV::VFWMUL_VV:
// Vector Single-Width Floating-Point Fused Multiply-Add Instructions
case RISCV::VFMACC_VV:
case RISCV::VFMACC_VF:
case RISCV::VFNMACC_VV:
case RISCV::VFNMACC_VF:
case RISCV::VFMSAC_VV:
case RISCV::VFMSAC_VF:
case RISCV::VFNMSAC_VV:
case RISCV::VFNMSAC_VF:
case RISCV::VFMADD_VV:
case RISCV::VFMADD_VF:
case RISCV::VFNMADD_VV:
case RISCV::VFNMADD_VF:
case RISCV::VFMSUB_VV:
case RISCV::VFMSUB_VF:
case RISCV::VFNMSUB_VV:
case RISCV::VFNMSUB_VF:
// Vector Widening Floating-Point Fused Multiply-Add Instructions
case RISCV::VFWMACC_VV:
case RISCV::VFWMACC_VF:
case RISCV::VFWNMACC_VV:
case RISCV::VFWNMACC_VF:
case RISCV::VFWMSAC_VV:
case RISCV::VFWMSAC_VF:
case RISCV::VFWNMSAC_VV:
case RISCV::VFWNMSAC_VF:
case RISCV::VFWMACCBF16_VV:
case RISCV::VFWMACCBF16_VF:
// Vector Floating-Point Square-Root Instruction
case RISCV::VFSQRT_V:
// Vector Floating-Point Reciprocal Square-Root Estimate Instruction
case RISCV::VFRSQRT7_V:
// Vector Floating-Point Reciprocal Estimate Instruction
case RISCV::VFREC7_V:
// Vector Floating-Point MIN/MAX Instructions
case RISCV::VFMIN_VF:
case RISCV::VFMIN_VV:
case RISCV::VFMAX_VF:
case RISCV::VFMAX_VV:
// Vector Floating-Point Sign-Injection Instructions
case RISCV::VFSGNJ_VF:
case RISCV::VFSGNJ_VV:
case RISCV::VFSGNJN_VV:
case RISCV::VFSGNJN_VF:
case RISCV::VFSGNJX_VF:
case RISCV::VFSGNJX_VV:
// Vector Floating-Point Compare Instructions
case RISCV::VMFEQ_VF:
case RISCV::VMFEQ_VV:
case RISCV::VMFNE_VF:
case RISCV::VMFNE_VV:
case RISCV::VMFLT_VF:
case RISCV::VMFLT_VV:
case RISCV::VMFLE_VF:
case RISCV::VMFLE_VV:
case RISCV::VMFGT_VF:
case RISCV::VMFGE_VF:
// Vector Floating-Point Classify Instruction
case RISCV::VFCLASS_V:
// Vector Floating-Point Merge Instruction
case RISCV::VFMERGE_VFM:
// Vector Floating-Point Move Instruction
case RISCV::VFMV_V_F:
// Single-Width Floating-Point/Integer Type-Convert Instructions
case RISCV::VFCVT_XU_F_V:
case RISCV::VFCVT_X_F_V:
case RISCV::VFCVT_RTZ_XU_F_V:
case RISCV::VFCVT_RTZ_X_F_V:
case RISCV::VFCVT_F_XU_V:
case RISCV::VFCVT_F_X_V:
// Widening Floating-Point/Integer Type-Convert Instructions
case RISCV::VFWCVT_XU_F_V:
case RISCV::VFWCVT_X_F_V:
case RISCV::VFWCVT_RTZ_XU_F_V:
case RISCV::VFWCVT_RTZ_X_F_V:
case RISCV::VFWCVT_F_XU_V:
case RISCV::VFWCVT_F_X_V:
case RISCV::VFWCVT_F_F_V:
case RISCV::VFWCVTBF16_F_F_V:
// Narrowing Floating-Point/Integer Type-Convert Instructions
case RISCV::VFNCVT_XU_F_W:
case RISCV::VFNCVT_X_F_W:
case RISCV::VFNCVT_RTZ_XU_F_W:
case RISCV::VFNCVT_RTZ_X_F_W:
case RISCV::VFNCVT_F_XU_W:
case RISCV::VFNCVT_F_X_W:
case RISCV::VFNCVT_F_F_W:
case RISCV::VFNCVT_ROD_F_F_W:
case RISCV::VFNCVTBF16_F_F_W:
return true;
}
return false;
}
/// Return true if MO is a vector operand but is used as a scalar operand.
static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
const MachineInstr *MI = MO.getParent();
const RISCVVPseudosTable::PseudoInfo *RVV =
RISCVVPseudosTable::getPseudoInfo(MI->getOpcode());
if (!RVV)
return false;
switch (RVV->BaseInstr) {
// Reductions only use vs1[0] of vs1
case RISCV::VREDAND_VS:
case RISCV::VREDMAX_VS:
case RISCV::VREDMAXU_VS:
case RISCV::VREDMIN_VS:
case RISCV::VREDMINU_VS:
case RISCV::VREDOR_VS:
case RISCV::VREDSUM_VS:
case RISCV::VREDXOR_VS:
case RISCV::VWREDSUM_VS:
case RISCV::VWREDSUMU_VS:
case RISCV::VFREDMAX_VS:
case RISCV::VFREDMIN_VS:
case RISCV::VFREDOSUM_VS:
case RISCV::VFREDUSUM_VS:
case RISCV::VFWREDOSUM_VS:
case RISCV::VFWREDUSUM_VS:
return MO.getOperandNo() == 3;
case RISCV::VMV_X_S:
case RISCV::VFMV_F_S:
return MO.getOperandNo() == 1;
default:
return false;
}
}
bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
const MCInstrDesc &Desc = MI.getDesc();
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags))
return false;
if (MI.getNumExplicitDefs() != 1)
return false;
// Some instructions have implicit defs e.g. $vxsat. If they might be read
// later then we can't reduce VL.
if (!MI.allImplicitDefsAreDead()) {
LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
return false;
}
if (MI.mayRaiseFPException()) {
LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
return false;
}
// Some instructions that produce vectors have semantics that make it more
// difficult to determine whether the VL can be reduced. For example, some
// instructions, such as reductions, may write lanes past VL to a scalar
// register. Other instructions, such as some loads or stores, may write
// lower lanes using data from higher lanes. There may be other complex
// semantics not mentioned here that make it hard to determine whether
// the VL can be optimized. As a result, a white-list of supported
// instructions is used. Over time, more instructions can be supported
// upon careful examination of their semantics under the logic in this
// optimization.
// TODO: Use a better approach than a white-list, such as adding
// properties to instructions using something like TSFlags.
if (!isSupportedInstr(MI)) {
LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
return false;
}
assert(!RISCVII::elementsDependOnVL(
TII->get(RISCV::getRVVMCOpcode(MI.getOpcode())).TSFlags) &&
"Instruction shouldn't be supported if elements depend on VL");
assert(RISCVRI::isVRegClass(
MRI->getRegClass(MI.getOperand(0).getReg())->TSFlags) &&
"All supported instructions produce a vector register result");
LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
return true;
}
std::optional<MachineOperand>
RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
const MachineInstr &UserMI = *UserOp.getParent();
const MCInstrDesc &Desc = UserMI.getDesc();
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
" use VLMAX\n");
return std::nullopt;
}
if (RISCVII::readsPastVL(
TII->get(RISCV::getRVVMCOpcode(UserMI.getOpcode())).TSFlags)) {
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
return std::nullopt;
}
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
// Looking for an immediate or a register VL that isn't X0.
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
"Did not expect X0 VL");
// If the user is a passthru it will read the elements past VL, so
// abort if any of the elements past VL are demanded.
if (UserOp.isTied()) {
assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
auto DemandedVL = DemandedVLs.lookup(&UserMI);
if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) {
LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
"instruction with demanded tail\n");
return std::nullopt;
}
}
// Instructions like reductions may use a vector register as a scalar
// register. In this case, we should treat it as only reading the first lane.
if (isVectorOpUsedAsScalarOp(UserOp)) {
LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");
return MachineOperand::CreateImm(1);
}
// If we know the demanded VL of UserMI, then we can reduce the VL it
// requires.
if (auto DemandedVL = DemandedVLs.lookup(&UserMI)) {
assert(isCandidate(UserMI));
if (RISCV::isVLKnownLE(*DemandedVL, VLOp))
return DemandedVL;
}
return VLOp;
}
/// Compute the largest VL that any user of MI's result demands.
///
/// Every use of the register defined by MI's operand 0 is visited, looking
/// through full COPYs into virtual registers and through PHIs. Each remaining
/// user contributes the VL reported by getMinimumVLForUser(), and the user's
/// operand must also be EEW/EMUL-compatible with MI's result.
///
/// \returns the common (largest) VL among all users, or std::nullopt if any
/// user prevents the reduction, if the VLs cannot be ordered statically, or
/// if no user contributed a VL.
std::optional<MachineOperand>
RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
  std::optional<MachineOperand> CommonVL;
  SmallSetVector<MachineOperand *, 8> Worklist;
  SmallPtrSet<const MachineInstr *, 4> PHISeen;
  for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg()))
    Worklist.insert(&UserOp);

  // The producer side of the compatibility check depends only on MI, so
  // compute it once here instead of once per user inside the loop.
  std::optional<OperandInfo> ProducerInfo =
      getOperandInfo(MI.getOperand(0), MRI);

  while (!Worklist.empty()) {
    MachineOperand &UserOp = *Worklist.pop_back_val();
    const MachineInstr &UserMI = *UserOp.getParent();
    LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");

    // A full copy into a virtual register just forwards the value, so the
    // real demand comes from the users of the copy's result.
    if (UserMI.isFullCopy() && UserMI.getOperand(0).getReg().isVirtual()) {
      LLVM_DEBUG(dbgs() << " Peeking through uses of COPY\n");
      Worklist.insert_range(llvm::make_pointer_range(
          MRI->use_operands(UserMI.getOperand(0).getReg())));
      continue;
    }

    if (UserMI.isPHI()) {
      // Don't follow PHI cycles
      if (!PHISeen.insert(&UserMI).second)
        continue;
      LLVM_DEBUG(dbgs() << " Peeking through uses of PHI\n");
      Worklist.insert_range(llvm::make_pointer_range(
          MRI->use_operands(UserMI.getOperand(0).getReg())));
      continue;
    }

    auto VLOp = getMinimumVLForUser(UserOp);
    if (!VLOp)
      return std::nullopt;

    // Use the largest VL among all the users. If we cannot determine this
    // statically, then we cannot optimize the VL.
    if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
      CommonVL = *VLOp;
      LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
    } else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
      LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
      return std::nullopt;
    }

    // getMinimumVLForUser() already rejects users without a SEW operand; this
    // is kept as a defensive guard ahead of the operand-info query below.
    if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
      LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
      return std::nullopt;
    }

    std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
    if (!ConsumerInfo || !ProducerInfo) {
      LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
      LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
      return std::nullopt;
    }

    if (!OperandInfo::areCompatible(*ProducerInfo, *ConsumerInfo)) {
      LLVM_DEBUG(
          dbgs()
          << " Abort due to incompatible information for EMUL or EEW.\n");
      LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
      LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
      return std::nullopt;
    }
  }

  return CommonVL;
}
/// Try to replace MI's VL operand with the VL recorded for it in DemandedVLs.
///
/// The rewrite happens only when a demanded VL is recorded, it is known to be
/// <= the current VL, it differs from the current VL, and (for a register VL)
/// the register's definition dominates MI.
///
/// \returns true if MI's VL operand was changed.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
  LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");

  MachineOperand &VLOp = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));

  // A VL of 1 cannot usefully be reduced. This early exit is purely an
  // optimization; correctness does not depend on it.
  const bool AlreadyMinimal = VLOp.isImm() && VLOp.getImm() == 1;
  if (AlreadyMinimal) {
    LLVM_DEBUG(dbgs() << " Abort due to VL == 1, no point in reducing.\n");
    return false;
  }

  const auto NewVL = DemandedVLs.lookup(&MI);
  if (!NewVL)
    return false;

  assert((NewVL->isImm() || NewVL->getReg().isVirtual()) &&
         "Expected VL to be an Imm or virtual Reg");

  if (!RISCV::isVLKnownLE(*NewVL, VLOp)) {
    LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
    return false;
  }

  if (NewVL->isIdenticalTo(VLOp)) {
    LLVM_DEBUG(
        dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
    return false;
  }

  // Immediate VLs can be substituted directly.
  if (NewVL->isImm()) {
    const auto NewImm = NewVL->getImm();
    LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to " << NewImm
                      << " for " << MI << "\n");
    VLOp.ChangeToImmediate(NewImm);
    return true;
  }

  // A register VL may only be used by MI if its definition dominates MI.
  const auto NewVLReg = NewVL->getReg();
  const MachineInstr *VLDef = MRI->getVRegDef(NewVLReg);
  if (!MDT->dominates(VLDef, &MI))
    return false;

  LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                    << printReg(NewVLReg, MRI->getTargetRegisterInfo())
                    << " for " << MI << "\n");

  // Every check passed, so switch MI over to the smaller register VL.
  VLOp.ChangeToRegister(NewVLReg, false);
  return true;
}
/// Pass entry point.
///
/// Phase 1 records, for every candidate instruction, the VL its users demand
/// (as computed by checkUsers()). Phase 2 revisits the candidates and rewrites
/// their VL operand via tryReduceVL().
///
/// \returns true if any instruction's VL operand was modified.
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
  if (skipFunction(MF.getFunction()))
    return false;

  MRI = &MF.getRegInfo();
  MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();

  const auto &Subtarget = MF.getSubtarget<RISCVSubtarget>();
  if (!Subtarget.hasVInstructions())
    return false;
  TII = Subtarget.getInstrInfo();

  assert(DemandedVLs.empty());

  // Phase 1: walk blocks in post-order and instructions bottom-up, recording
  // for each vector-defining candidate the VL demanded by its downstream
  // users.
  for (MachineBasicBlock *MBB : post_order(&MF)) {
    assert(MDT->isReachableFromEntry(MBB));
    for (MachineInstr &MI : reverse(*MBB))
      if (isCandidate(MI))
        DemandedVLs.insert({&MI, checkUsers(MI)});
  }

  // Phase 2: shrink each candidate's VL down to what was found to be
  // demanded.
  bool Changed = false;
  for (MachineBasicBlock &MBB : MF) {
    // Unreachable blocks have degenerate dominance information; skip them.
    if (!MDT->isReachableFromEntry(&MBB))
      continue;
    for (MachineInstr &MI : reverse(MBB))
      if (isCandidate(MI) && tryReduceVL(MI))
        Changed = true;
  }

  // Drop the per-function state so the next function starts from empty.
  DemandedVLs.clear();
  return Changed;
}