1494 lines
45 KiB
C++
1494 lines
45 KiB
C++
//===-------------- RISCVVLOptimizer.cpp - VL Optimizer -------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===---------------------------------------------------------------------===//
|
|
//
|
|
// This pass reduces the VL where possible at the MI level, before VSETVLI
|
|
// instructions are inserted.
|
|
//
|
|
// The purpose of this optimization is to make the VL argument, for instructions
|
|
// that have a VL argument, as small as possible. This is implemented by
|
|
// visiting each instruction in reverse order and, for each one that has a VL
// argument, checking whether that VL can be reduced.
|
|
//
|
|
//===---------------------------------------------------------------------===//
|
|
|
|
#include "RISCV.h"
|
|
#include "RISCVSubtarget.h"
|
|
#include "llvm/ADT/PostOrderIterator.h"
|
|
#include "llvm/CodeGen/MachineDominators.h"
|
|
#include "llvm/CodeGen/MachineFunctionPass.h"
|
|
#include "llvm/InitializePasses.h"
|
|
|
|
using namespace llvm;
|
|
|
|
#define DEBUG_TYPE "riscv-vl-optimizer"
|
|
#define PASS_NAME "RISC-V VL Optimizer"
|
|
|
|
namespace {
|
|
|
|
class RISCVVLOptimizer : public MachineFunctionPass {
|
|
const MachineRegisterInfo *MRI;
|
|
const MachineDominatorTree *MDT;
|
|
|
|
public:
|
|
static char ID;
|
|
|
|
RISCVVLOptimizer() : MachineFunctionPass(ID) {}
|
|
|
|
bool runOnMachineFunction(MachineFunction &MF) override;
|
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
AU.setPreservesCFG();
|
|
AU.addRequired<MachineDominatorTreeWrapperPass>();
|
|
MachineFunctionPass::getAnalysisUsage(AU);
|
|
}
|
|
|
|
StringRef getPassName() const override { return PASS_NAME; }
|
|
|
|
private:
|
|
std::optional<MachineOperand>
|
|
getMinimumVLForUser(const MachineOperand &UserOp) const;
|
|
/// Returns the largest common VL MachineOperand that may be used to optimize
|
|
/// MI. Returns std::nullopt if it failed to find a suitable VL.
|
|
std::optional<MachineOperand> checkUsers(const MachineInstr &MI) const;
|
|
bool tryReduceVL(MachineInstr &MI) const;
|
|
bool isCandidate(const MachineInstr &MI) const;
|
|
|
|
/// For a given instruction, records what elements of it are demanded by
|
|
/// downstream users.
|
|
DenseMap<const MachineInstr *, std::optional<MachineOperand>> DemandedVLs;
|
|
};
|
|
|
|
/// Represents the EMUL and EEW of a MachineOperand.
|
|
struct OperandInfo {
|
|
// Represent as 1,2,4,8, ... and fractional indicator. This is because
|
|
// EMUL can take on values that don't map to RISCVVType::VLMUL values exactly.
|
|
// For example, a mask operand can have an EMUL less than MF8.
|
|
std::optional<std::pair<unsigned, bool>> EMUL;
|
|
|
|
unsigned Log2EEW;
|
|
|
|
OperandInfo(RISCVVType::VLMUL EMUL, unsigned Log2EEW)
|
|
: EMUL(RISCVVType::decodeVLMUL(EMUL)), Log2EEW(Log2EEW) {}
|
|
|
|
OperandInfo(std::pair<unsigned, bool> EMUL, unsigned Log2EEW)
|
|
: EMUL(EMUL), Log2EEW(Log2EEW) {}
|
|
|
|
OperandInfo(unsigned Log2EEW) : Log2EEW(Log2EEW) {}
|
|
|
|
OperandInfo() = delete;
|
|
|
|
static bool EMULAndEEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
|
|
return A.Log2EEW == B.Log2EEW && A.EMUL == B.EMUL;
|
|
}
|
|
|
|
static bool EEWAreEqual(const OperandInfo &A, const OperandInfo &B) {
|
|
return A.Log2EEW == B.Log2EEW;
|
|
}
|
|
|
|
void print(raw_ostream &OS) const {
|
|
if (EMUL) {
|
|
OS << "EMUL: m";
|
|
if (EMUL->second)
|
|
OS << "f";
|
|
OS << EMUL->first;
|
|
} else
|
|
OS << "EMUL: unknown\n";
|
|
OS << ", EEW: " << (1 << Log2EEW);
|
|
}
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
// Storage for the pass's static ID member; the constructor hands ID to the
// MachineFunctionPass base class.
char RISCVVLOptimizer::ID = 0;
|
|
// Registration of the pass under DEBUG_TYPE / PASS_NAME, declaring its
// dependency on MachineDominatorTreeWrapperPass.
INITIALIZE_PASS_BEGIN(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
|
|
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
|
|
INITIALIZE_PASS_END(RISCVVLOptimizer, DEBUG_TYPE, PASS_NAME, false, false)
|
|
|
|
/// Factory for the VL optimizer: hands back a freshly heap-allocated
/// RISCVVLOptimizer through its FunctionPass base pointer.
FunctionPass *llvm::createRISCVVLOptimizerPass() {
  FunctionPass *NewPass = new RISCVVLOptimizer();
  return NewPass;
}
|
|
|
|
/// Return true if R is a physical or virtual vector register, false otherwise.
|
|
static bool isVectorRegClass(Register R, const MachineRegisterInfo *MRI) {
|
|
if (R.isPhysical())
|
|
return RISCV::VRRegClass.contains(R);
|
|
const TargetRegisterClass *RC = MRI->getRegClass(R);
|
|
return RISCVRI::isVRegClass(RC->TSFlags);
|
|
}
|
|
|
|
/// Stream an OperandInfo by delegating to OperandInfo::print; returns \p OS so
/// that insertions can be chained.
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS, const OperandInfo &OI) {
  OI.print(OS);
  return OS;
}
|
|
|
|
/// Stream an optional OperandInfo: an empty optional is written as "nullopt",
/// otherwise the contained value is printed. Returns \p OS for chaining.
LLVM_ATTRIBUTE_UNUSED
static raw_ostream &operator<<(raw_ostream &OS,
                               const std::optional<OperandInfo> &OI) {
  if (!OI) {
    OS << "nullopt";
    return OS;
  }
  OI->print(OS);
  return OS;
}
|
|
|
|
/// Return EMUL = (EEW / SEW) * LMUL where EEW comes from Log2EEW and LMUL and
|
|
/// SEW are from the TSFlags of MI.
|
|
static std::pair<unsigned, bool>
|
|
getEMULEqualsEEWDivSEWTimesLMUL(unsigned Log2EEW, const MachineInstr &MI) {
|
|
RISCVVType::VLMUL MIVLMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
|
|
auto [MILMUL, MILMULIsFractional] = RISCVVType::decodeVLMUL(MIVLMUL);
|
|
unsigned MILog2SEW =
|
|
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
|
|
|
|
// Mask instructions will have 0 as the SEW operand. But the LMUL of these
|
|
// instructions is calculated is as if the SEW operand was 3 (e8).
|
|
if (MILog2SEW == 0)
|
|
MILog2SEW = 3;
|
|
|
|
unsigned MISEW = 1 << MILog2SEW;
|
|
|
|
unsigned EEW = 1 << Log2EEW;
|
|
// Calculate (EEW/SEW)*LMUL preserving fractions less than 1. Use GCD
|
|
// to put fraction in simplest form.
|
|
unsigned Num = EEW, Denom = MISEW;
|
|
int GCD = MILMULIsFractional ? std::gcd(Num, Denom * MILMUL)
|
|
: std::gcd(Num * MILMUL, Denom);
|
|
Num = MILMULIsFractional ? Num / GCD : Num * MILMUL / GCD;
|
|
Denom = MILMULIsFractional ? Denom * MILMUL / GCD : Denom / GCD;
|
|
return std::make_pair(Num > Denom ? Num : Denom, Denom > Num);
|
|
}
|
|
|
|
/// Dest has EEW=SEW. Source EEW=SEW/Factor (i.e. F2 => EEW/2).
|
|
/// SEW comes from TSFlags of MI.
|
|
static unsigned getIntegerExtensionOperandEEW(unsigned Factor,
|
|
const MachineInstr &MI,
|
|
const MachineOperand &MO) {
|
|
unsigned MILog2SEW =
|
|
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
|
|
|
|
if (MO.getOperandNo() == 0)
|
|
return MILog2SEW;
|
|
|
|
unsigned MISEW = 1 << MILog2SEW;
|
|
unsigned EEW = MISEW / Factor;
|
|
unsigned Log2EEW = Log2_32(EEW);
|
|
|
|
return Log2EEW;
|
|
}
|
|
|
|
/// Check whether MO is a mask operand of MI.
|
|
static bool isMaskOperand(const MachineInstr &MI, const MachineOperand &MO,
|
|
const MachineRegisterInfo *MRI) {
|
|
|
|
if (!MO.isReg() || !isVectorRegClass(MO.getReg(), MRI))
|
|
return false;
|
|
|
|
const MCInstrDesc &Desc = MI.getDesc();
|
|
return Desc.operands()[MO.getOperandNo()].RegClass == RISCV::VMV0RegClassID;
|
|
}
|
|
|
|
static std::optional<unsigned>
|
|
getOperandLog2EEW(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
|
|
const MachineInstr &MI = *MO.getParent();
|
|
const RISCVVPseudosTable::PseudoInfo *RVV =
|
|
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
|
|
assert(RVV && "Could not find MI in PseudoTable");
|
|
|
|
// MI has a SEW associated with it. The RVV specification defines
|
|
// the EEW of each operand and definition in relation to MI.SEW.
|
|
unsigned MILog2SEW =
|
|
MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
|
|
|
|
const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MI.getDesc());
|
|
const bool IsTied = RISCVII::isTiedPseudo(MI.getDesc().TSFlags);
|
|
|
|
bool IsMODef = MO.getOperandNo() == 0 ||
|
|
(HasPassthru && MO.getOperandNo() == MI.getNumExplicitDefs());
|
|
|
|
// All mask operands have EEW=1
|
|
if (isMaskOperand(MI, MO, MRI))
|
|
return 0;
|
|
|
|
// switch against BaseInstr to reduce number of cases that need to be
|
|
// considered.
|
|
switch (RVV->BaseInstr) {
|
|
|
|
// 6. Configuration-Setting Instructions
|
|
// Configuration setting instructions do not read or write vector registers
|
|
case RISCV::VSETIVLI:
|
|
case RISCV::VSETVL:
|
|
case RISCV::VSETVLI:
|
|
llvm_unreachable("Configuration setting instructions do not read or write "
|
|
"vector registers");
|
|
|
|
// Vector Loads and Stores
|
|
// Vector Unit-Stride Instructions
|
|
// Vector Strided Instructions
|
|
/// Dest EEW encoded in the instruction
|
|
case RISCV::VLM_V:
|
|
case RISCV::VSM_V:
|
|
return 0;
|
|
case RISCV::VLE8_V:
|
|
case RISCV::VSE8_V:
|
|
case RISCV::VLSE8_V:
|
|
case RISCV::VSSE8_V:
|
|
return 3;
|
|
case RISCV::VLE16_V:
|
|
case RISCV::VSE16_V:
|
|
case RISCV::VLSE16_V:
|
|
case RISCV::VSSE16_V:
|
|
return 4;
|
|
case RISCV::VLE32_V:
|
|
case RISCV::VSE32_V:
|
|
case RISCV::VLSE32_V:
|
|
case RISCV::VSSE32_V:
|
|
return 5;
|
|
case RISCV::VLE64_V:
|
|
case RISCV::VSE64_V:
|
|
case RISCV::VLSE64_V:
|
|
case RISCV::VSSE64_V:
|
|
return 6;
|
|
|
|
// Vector Indexed Instructions
|
|
// vs(o|u)xei<eew>.v
|
|
// Dest/Data (operand 0) EEW=SEW. Source EEW=<eew>.
|
|
case RISCV::VLUXEI8_V:
|
|
case RISCV::VLOXEI8_V:
|
|
case RISCV::VSUXEI8_V:
|
|
case RISCV::VSOXEI8_V: {
|
|
if (MO.getOperandNo() == 0)
|
|
return MILog2SEW;
|
|
return 3;
|
|
}
|
|
case RISCV::VLUXEI16_V:
|
|
case RISCV::VLOXEI16_V:
|
|
case RISCV::VSUXEI16_V:
|
|
case RISCV::VSOXEI16_V: {
|
|
if (MO.getOperandNo() == 0)
|
|
return MILog2SEW;
|
|
return 4;
|
|
}
|
|
case RISCV::VLUXEI32_V:
|
|
case RISCV::VLOXEI32_V:
|
|
case RISCV::VSUXEI32_V:
|
|
case RISCV::VSOXEI32_V: {
|
|
if (MO.getOperandNo() == 0)
|
|
return MILog2SEW;
|
|
return 5;
|
|
}
|
|
case RISCV::VLUXEI64_V:
|
|
case RISCV::VLOXEI64_V:
|
|
case RISCV::VSUXEI64_V:
|
|
case RISCV::VSOXEI64_V: {
|
|
if (MO.getOperandNo() == 0)
|
|
return MILog2SEW;
|
|
return 6;
|
|
}
|
|
|
|
// Vector Integer Arithmetic Instructions
|
|
// Vector Single-Width Integer Add and Subtract
|
|
case RISCV::VADD_VI:
|
|
case RISCV::VADD_VV:
|
|
case RISCV::VADD_VX:
|
|
case RISCV::VSUB_VV:
|
|
case RISCV::VSUB_VX:
|
|
case RISCV::VRSUB_VI:
|
|
case RISCV::VRSUB_VX:
|
|
// Vector Bitwise Logical Instructions
|
|
// Vector Single-Width Shift Instructions
|
|
// EEW=SEW.
|
|
case RISCV::VAND_VI:
|
|
case RISCV::VAND_VV:
|
|
case RISCV::VAND_VX:
|
|
case RISCV::VOR_VI:
|
|
case RISCV::VOR_VV:
|
|
case RISCV::VOR_VX:
|
|
case RISCV::VXOR_VI:
|
|
case RISCV::VXOR_VV:
|
|
case RISCV::VXOR_VX:
|
|
case RISCV::VSLL_VI:
|
|
case RISCV::VSLL_VV:
|
|
case RISCV::VSLL_VX:
|
|
case RISCV::VSRL_VI:
|
|
case RISCV::VSRL_VV:
|
|
case RISCV::VSRL_VX:
|
|
case RISCV::VSRA_VI:
|
|
case RISCV::VSRA_VV:
|
|
case RISCV::VSRA_VX:
|
|
// Vector Integer Min/Max Instructions
|
|
// EEW=SEW.
|
|
case RISCV::VMINU_VV:
|
|
case RISCV::VMINU_VX:
|
|
case RISCV::VMIN_VV:
|
|
case RISCV::VMIN_VX:
|
|
case RISCV::VMAXU_VV:
|
|
case RISCV::VMAXU_VX:
|
|
case RISCV::VMAX_VV:
|
|
case RISCV::VMAX_VX:
|
|
// Vector Single-Width Integer Multiply Instructions
|
|
// Source and Dest EEW=SEW.
|
|
case RISCV::VMUL_VV:
|
|
case RISCV::VMUL_VX:
|
|
case RISCV::VMULH_VV:
|
|
case RISCV::VMULH_VX:
|
|
case RISCV::VMULHU_VV:
|
|
case RISCV::VMULHU_VX:
|
|
case RISCV::VMULHSU_VV:
|
|
case RISCV::VMULHSU_VX:
|
|
// Vector Integer Divide Instructions
|
|
// EEW=SEW.
|
|
case RISCV::VDIVU_VV:
|
|
case RISCV::VDIVU_VX:
|
|
case RISCV::VDIV_VV:
|
|
case RISCV::VDIV_VX:
|
|
case RISCV::VREMU_VV:
|
|
case RISCV::VREMU_VX:
|
|
case RISCV::VREM_VV:
|
|
case RISCV::VREM_VX:
|
|
// Vector Single-Width Integer Multiply-Add Instructions
|
|
// EEW=SEW.
|
|
case RISCV::VMACC_VV:
|
|
case RISCV::VMACC_VX:
|
|
case RISCV::VNMSAC_VV:
|
|
case RISCV::VNMSAC_VX:
|
|
case RISCV::VMADD_VV:
|
|
case RISCV::VMADD_VX:
|
|
case RISCV::VNMSUB_VV:
|
|
case RISCV::VNMSUB_VX:
|
|
// Vector Integer Merge Instructions
|
|
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
|
|
// EEW=SEW, except the mask operand has EEW=1. Mask operand is handled
|
|
// before this switch.
|
|
case RISCV::VMERGE_VIM:
|
|
case RISCV::VMERGE_VVM:
|
|
case RISCV::VMERGE_VXM:
|
|
case RISCV::VADC_VIM:
|
|
case RISCV::VADC_VVM:
|
|
case RISCV::VADC_VXM:
|
|
case RISCV::VSBC_VVM:
|
|
case RISCV::VSBC_VXM:
|
|
// Vector Integer Move Instructions
|
|
// Vector Fixed-Point Arithmetic Instructions
|
|
// Vector Single-Width Saturating Add and Subtract
|
|
// Vector Single-Width Averaging Add and Subtract
|
|
// EEW=SEW.
|
|
case RISCV::VMV_V_I:
|
|
case RISCV::VMV_V_V:
|
|
case RISCV::VMV_V_X:
|
|
case RISCV::VSADDU_VI:
|
|
case RISCV::VSADDU_VV:
|
|
case RISCV::VSADDU_VX:
|
|
case RISCV::VSADD_VI:
|
|
case RISCV::VSADD_VV:
|
|
case RISCV::VSADD_VX:
|
|
case RISCV::VSSUBU_VV:
|
|
case RISCV::VSSUBU_VX:
|
|
case RISCV::VSSUB_VV:
|
|
case RISCV::VSSUB_VX:
|
|
case RISCV::VAADDU_VV:
|
|
case RISCV::VAADDU_VX:
|
|
case RISCV::VAADD_VV:
|
|
case RISCV::VAADD_VX:
|
|
case RISCV::VASUBU_VV:
|
|
case RISCV::VASUBU_VX:
|
|
case RISCV::VASUB_VV:
|
|
case RISCV::VASUB_VX:
|
|
// Vector Single-Width Fractional Multiply with Rounding and Saturation
|
|
// EEW=SEW. The instruction produces 2*SEW product internally but
|
|
// saturates to fit into SEW bits.
|
|
case RISCV::VSMUL_VV:
|
|
case RISCV::VSMUL_VX:
|
|
// Vector Single-Width Scaling Shift Instructions
|
|
// EEW=SEW.
|
|
case RISCV::VSSRL_VI:
|
|
case RISCV::VSSRL_VV:
|
|
case RISCV::VSSRL_VX:
|
|
case RISCV::VSSRA_VI:
|
|
case RISCV::VSSRA_VV:
|
|
case RISCV::VSSRA_VX:
|
|
// Vector Permutation Instructions
|
|
// Integer Scalar Move Instructions
|
|
// Floating-Point Scalar Move Instructions
|
|
// EEW=SEW.
|
|
case RISCV::VMV_X_S:
|
|
case RISCV::VMV_S_X:
|
|
case RISCV::VFMV_F_S:
|
|
case RISCV::VFMV_S_F:
|
|
// Vector Slide Instructions
|
|
// EEW=SEW.
|
|
case RISCV::VSLIDEUP_VI:
|
|
case RISCV::VSLIDEUP_VX:
|
|
case RISCV::VSLIDEDOWN_VI:
|
|
case RISCV::VSLIDEDOWN_VX:
|
|
case RISCV::VSLIDE1UP_VX:
|
|
case RISCV::VFSLIDE1UP_VF:
|
|
case RISCV::VSLIDE1DOWN_VX:
|
|
case RISCV::VFSLIDE1DOWN_VF:
|
|
// Vector Register Gather Instructions
|
|
// EEW=SEW. For mask operand, EEW=1.
|
|
case RISCV::VRGATHER_VI:
|
|
case RISCV::VRGATHER_VV:
|
|
case RISCV::VRGATHER_VX:
|
|
// Vector Compress Instruction
|
|
// EEW=SEW.
|
|
case RISCV::VCOMPRESS_VM:
|
|
// Vector Element Index Instruction
|
|
case RISCV::VID_V:
|
|
// Vector Single-Width Floating-Point Add/Subtract Instructions
|
|
case RISCV::VFADD_VF:
|
|
case RISCV::VFADD_VV:
|
|
case RISCV::VFSUB_VF:
|
|
case RISCV::VFSUB_VV:
|
|
case RISCV::VFRSUB_VF:
|
|
// Vector Single-Width Floating-Point Multiply/Divide Instructions
|
|
case RISCV::VFMUL_VF:
|
|
case RISCV::VFMUL_VV:
|
|
case RISCV::VFDIV_VF:
|
|
case RISCV::VFDIV_VV:
|
|
case RISCV::VFRDIV_VF:
|
|
// Vector Single-Width Floating-Point Fused Multiply-Add Instructions
|
|
case RISCV::VFMACC_VV:
|
|
case RISCV::VFMACC_VF:
|
|
case RISCV::VFNMACC_VV:
|
|
case RISCV::VFNMACC_VF:
|
|
case RISCV::VFMSAC_VV:
|
|
case RISCV::VFMSAC_VF:
|
|
case RISCV::VFNMSAC_VV:
|
|
case RISCV::VFNMSAC_VF:
|
|
case RISCV::VFMADD_VV:
|
|
case RISCV::VFMADD_VF:
|
|
case RISCV::VFNMADD_VV:
|
|
case RISCV::VFNMADD_VF:
|
|
case RISCV::VFMSUB_VV:
|
|
case RISCV::VFMSUB_VF:
|
|
case RISCV::VFNMSUB_VV:
|
|
case RISCV::VFNMSUB_VF:
|
|
// Vector Floating-Point Square-Root Instruction
|
|
case RISCV::VFSQRT_V:
|
|
// Vector Floating-Point Reciprocal Square-Root Estimate Instruction
|
|
case RISCV::VFRSQRT7_V:
|
|
// Vector Floating-Point Reciprocal Estimate Instruction
|
|
case RISCV::VFREC7_V:
|
|
// Vector Floating-Point MIN/MAX Instructions
|
|
case RISCV::VFMIN_VF:
|
|
case RISCV::VFMIN_VV:
|
|
case RISCV::VFMAX_VF:
|
|
case RISCV::VFMAX_VV:
|
|
// Vector Floating-Point Sign-Injection Instructions
|
|
case RISCV::VFSGNJ_VF:
|
|
case RISCV::VFSGNJ_VV:
|
|
case RISCV::VFSGNJN_VV:
|
|
case RISCV::VFSGNJN_VF:
|
|
case RISCV::VFSGNJX_VF:
|
|
case RISCV::VFSGNJX_VV:
|
|
// Vector Floating-Point Classify Instruction
|
|
case RISCV::VFCLASS_V:
|
|
// Vector Floating-Point Move Instruction
|
|
case RISCV::VFMV_V_F:
|
|
// Single-Width Floating-Point/Integer Type-Convert Instructions
|
|
case RISCV::VFCVT_XU_F_V:
|
|
case RISCV::VFCVT_X_F_V:
|
|
case RISCV::VFCVT_RTZ_XU_F_V:
|
|
case RISCV::VFCVT_RTZ_X_F_V:
|
|
case RISCV::VFCVT_F_XU_V:
|
|
case RISCV::VFCVT_F_X_V:
|
|
// Vector Floating-Point Merge Instruction
|
|
case RISCV::VFMERGE_VFM:
|
|
// Vector count population in mask vcpop.m
|
|
// vfirst find-first-set mask bit
|
|
case RISCV::VCPOP_M:
|
|
case RISCV::VFIRST_M:
|
|
return MILog2SEW;
|
|
|
|
// Vector Widening Integer Add/Subtract
|
|
// Def uses EEW=2*SEW . Operands use EEW=SEW.
|
|
case RISCV::VWADDU_VV:
|
|
case RISCV::VWADDU_VX:
|
|
case RISCV::VWSUBU_VV:
|
|
case RISCV::VWSUBU_VX:
|
|
case RISCV::VWADD_VV:
|
|
case RISCV::VWADD_VX:
|
|
case RISCV::VWSUB_VV:
|
|
case RISCV::VWSUB_VX:
|
|
case RISCV::VWSLL_VI:
|
|
// Vector Widening Integer Multiply Instructions
|
|
// Destination EEW=2*SEW. Source EEW=SEW.
|
|
case RISCV::VWMUL_VV:
|
|
case RISCV::VWMUL_VX:
|
|
case RISCV::VWMULSU_VV:
|
|
case RISCV::VWMULSU_VX:
|
|
case RISCV::VWMULU_VV:
|
|
case RISCV::VWMULU_VX:
|
|
// Vector Widening Integer Multiply-Add Instructions
|
|
// Destination EEW=2*SEW. Source EEW=SEW.
|
|
// A SEW-bit*SEW-bit multiply of the sources forms a 2*SEW-bit value, which
|
|
// is then added to the 2*SEW-bit Dest. These instructions never have a
|
|
// passthru operand.
|
|
case RISCV::VWMACCU_VV:
|
|
case RISCV::VWMACCU_VX:
|
|
case RISCV::VWMACC_VV:
|
|
case RISCV::VWMACC_VX:
|
|
case RISCV::VWMACCSU_VV:
|
|
case RISCV::VWMACCSU_VX:
|
|
case RISCV::VWMACCUS_VX:
|
|
// Vector Widening Floating-Point Fused Multiply-Add Instructions
|
|
case RISCV::VFWMACC_VF:
|
|
case RISCV::VFWMACC_VV:
|
|
case RISCV::VFWNMACC_VF:
|
|
case RISCV::VFWNMACC_VV:
|
|
case RISCV::VFWMSAC_VF:
|
|
case RISCV::VFWMSAC_VV:
|
|
case RISCV::VFWNMSAC_VF:
|
|
case RISCV::VFWNMSAC_VV:
|
|
case RISCV::VFWMACCBF16_VV:
|
|
case RISCV::VFWMACCBF16_VF:
|
|
// Vector Widening Floating-Point Add/Subtract Instructions
|
|
// Dest EEW=2*SEW. Source EEW=SEW.
|
|
case RISCV::VFWADD_VV:
|
|
case RISCV::VFWADD_VF:
|
|
case RISCV::VFWSUB_VV:
|
|
case RISCV::VFWSUB_VF:
|
|
// Vector Widening Floating-Point Multiply
|
|
case RISCV::VFWMUL_VF:
|
|
case RISCV::VFWMUL_VV:
|
|
// Widening Floating-Point/Integer Type-Convert Instructions
|
|
case RISCV::VFWCVT_XU_F_V:
|
|
case RISCV::VFWCVT_X_F_V:
|
|
case RISCV::VFWCVT_RTZ_XU_F_V:
|
|
case RISCV::VFWCVT_RTZ_X_F_V:
|
|
case RISCV::VFWCVT_F_XU_V:
|
|
case RISCV::VFWCVT_F_X_V:
|
|
case RISCV::VFWCVT_F_F_V:
|
|
case RISCV::VFWCVTBF16_F_F_V:
|
|
return IsMODef ? MILog2SEW + 1 : MILog2SEW;
|
|
|
|
// Def and Op1 uses EEW=2*SEW. Op2 uses EEW=SEW.
|
|
case RISCV::VWADDU_WV:
|
|
case RISCV::VWADDU_WX:
|
|
case RISCV::VWSUBU_WV:
|
|
case RISCV::VWSUBU_WX:
|
|
case RISCV::VWADD_WV:
|
|
case RISCV::VWADD_WX:
|
|
case RISCV::VWSUB_WV:
|
|
case RISCV::VWSUB_WX:
|
|
// Vector Widening Floating-Point Add/Subtract Instructions
|
|
case RISCV::VFWADD_WF:
|
|
case RISCV::VFWADD_WV:
|
|
case RISCV::VFWSUB_WF:
|
|
case RISCV::VFWSUB_WV: {
|
|
bool IsOp1 = (HasPassthru && !IsTied) ? MO.getOperandNo() == 2
|
|
: MO.getOperandNo() == 1;
|
|
bool TwoTimes = IsMODef || IsOp1;
|
|
return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
|
|
}
|
|
|
|
// Vector Integer Extension
|
|
case RISCV::VZEXT_VF2:
|
|
case RISCV::VSEXT_VF2:
|
|
return getIntegerExtensionOperandEEW(2, MI, MO);
|
|
case RISCV::VZEXT_VF4:
|
|
case RISCV::VSEXT_VF4:
|
|
return getIntegerExtensionOperandEEW(4, MI, MO);
|
|
case RISCV::VZEXT_VF8:
|
|
case RISCV::VSEXT_VF8:
|
|
return getIntegerExtensionOperandEEW(8, MI, MO);
|
|
|
|
// Vector Narrowing Integer Right Shift Instructions
|
|
// Destination EEW=SEW, Op 1 has EEW=2*SEW. Op2 has EEW=SEW
|
|
case RISCV::VNSRL_WX:
|
|
case RISCV::VNSRL_WI:
|
|
case RISCV::VNSRL_WV:
|
|
case RISCV::VNSRA_WI:
|
|
case RISCV::VNSRA_WV:
|
|
case RISCV::VNSRA_WX:
|
|
// Vector Narrowing Fixed-Point Clip Instructions
|
|
// Destination and Op1 EEW=SEW. Op2 EEW=2*SEW.
|
|
case RISCV::VNCLIPU_WI:
|
|
case RISCV::VNCLIPU_WV:
|
|
case RISCV::VNCLIPU_WX:
|
|
case RISCV::VNCLIP_WI:
|
|
case RISCV::VNCLIP_WV:
|
|
case RISCV::VNCLIP_WX:
|
|
// Narrowing Floating-Point/Integer Type-Convert Instructions
|
|
case RISCV::VFNCVT_XU_F_W:
|
|
case RISCV::VFNCVT_X_F_W:
|
|
case RISCV::VFNCVT_RTZ_XU_F_W:
|
|
case RISCV::VFNCVT_RTZ_X_F_W:
|
|
case RISCV::VFNCVT_F_XU_W:
|
|
case RISCV::VFNCVT_F_X_W:
|
|
case RISCV::VFNCVT_F_F_W:
|
|
case RISCV::VFNCVT_ROD_F_F_W:
|
|
case RISCV::VFNCVTBF16_F_F_W: {
|
|
assert(!IsTied);
|
|
bool IsOp1 = HasPassthru ? MO.getOperandNo() == 2 : MO.getOperandNo() == 1;
|
|
bool TwoTimes = IsOp1;
|
|
return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
|
|
}
|
|
|
|
// Vector Mask Instructions
|
|
// Vector Mask-Register Logical Instructions
|
|
// vmsbf.m set-before-first mask bit
|
|
// vmsif.m set-including-first mask bit
|
|
// vmsof.m set-only-first mask bit
|
|
// EEW=1
|
|
// We handle the cases when operand is a v0 mask operand above the switch,
|
|
// but these instructions may use non-v0 mask operands and need to be handled
|
|
// specifically.
|
|
case RISCV::VMAND_MM:
|
|
case RISCV::VMNAND_MM:
|
|
case RISCV::VMANDN_MM:
|
|
case RISCV::VMXOR_MM:
|
|
case RISCV::VMOR_MM:
|
|
case RISCV::VMNOR_MM:
|
|
case RISCV::VMORN_MM:
|
|
case RISCV::VMXNOR_MM:
|
|
case RISCV::VMSBF_M:
|
|
case RISCV::VMSIF_M:
|
|
case RISCV::VMSOF_M: {
|
|
return MILog2SEW;
|
|
}
|
|
|
|
// Vector Iota Instruction
|
|
// EEW=SEW, except the mask operand has EEW=1. Mask operand is not handled
|
|
// before this switch.
|
|
case RISCV::VIOTA_M: {
|
|
if (IsMODef || MO.getOperandNo() == 1)
|
|
return MILog2SEW;
|
|
return 0;
|
|
}
|
|
|
|
// Vector Integer Compare Instructions
|
|
// Dest EEW=1. Source EEW=SEW.
|
|
case RISCV::VMSEQ_VI:
|
|
case RISCV::VMSEQ_VV:
|
|
case RISCV::VMSEQ_VX:
|
|
case RISCV::VMSNE_VI:
|
|
case RISCV::VMSNE_VV:
|
|
case RISCV::VMSNE_VX:
|
|
case RISCV::VMSLTU_VV:
|
|
case RISCV::VMSLTU_VX:
|
|
case RISCV::VMSLT_VV:
|
|
case RISCV::VMSLT_VX:
|
|
case RISCV::VMSLEU_VV:
|
|
case RISCV::VMSLEU_VI:
|
|
case RISCV::VMSLEU_VX:
|
|
case RISCV::VMSLE_VV:
|
|
case RISCV::VMSLE_VI:
|
|
case RISCV::VMSLE_VX:
|
|
case RISCV::VMSGTU_VI:
|
|
case RISCV::VMSGTU_VX:
|
|
case RISCV::VMSGT_VI:
|
|
case RISCV::VMSGT_VX:
|
|
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
|
|
// Dest EEW=1. Source EEW=SEW. Mask source operand handled above this switch.
|
|
case RISCV::VMADC_VIM:
|
|
case RISCV::VMADC_VVM:
|
|
case RISCV::VMADC_VXM:
|
|
case RISCV::VMSBC_VVM:
|
|
case RISCV::VMSBC_VXM:
|
|
// Dest EEW=1. Source EEW=SEW.
|
|
case RISCV::VMADC_VV:
|
|
case RISCV::VMADC_VI:
|
|
case RISCV::VMADC_VX:
|
|
case RISCV::VMSBC_VV:
|
|
case RISCV::VMSBC_VX:
|
|
// 13.13. Vector Floating-Point Compare Instructions
|
|
// Dest EEW=1. Source EEW=SEW
|
|
case RISCV::VMFEQ_VF:
|
|
case RISCV::VMFEQ_VV:
|
|
case RISCV::VMFNE_VF:
|
|
case RISCV::VMFNE_VV:
|
|
case RISCV::VMFLT_VF:
|
|
case RISCV::VMFLT_VV:
|
|
case RISCV::VMFLE_VF:
|
|
case RISCV::VMFLE_VV:
|
|
case RISCV::VMFGT_VF:
|
|
case RISCV::VMFGE_VF: {
|
|
if (IsMODef)
|
|
return 0;
|
|
return MILog2SEW;
|
|
}
|
|
|
|
// Vector Reduction Operations
|
|
// Vector Single-Width Integer Reduction Instructions
|
|
case RISCV::VREDAND_VS:
|
|
case RISCV::VREDMAX_VS:
|
|
case RISCV::VREDMAXU_VS:
|
|
case RISCV::VREDMIN_VS:
|
|
case RISCV::VREDMINU_VS:
|
|
case RISCV::VREDOR_VS:
|
|
case RISCV::VREDSUM_VS:
|
|
case RISCV::VREDXOR_VS:
|
|
// Vector Single-Width Floating-Point Reduction Instructions
|
|
case RISCV::VFREDMAX_VS:
|
|
case RISCV::VFREDMIN_VS:
|
|
case RISCV::VFREDOSUM_VS:
|
|
case RISCV::VFREDUSUM_VS: {
|
|
return MILog2SEW;
|
|
}
|
|
|
|
// Vector Widening Integer Reduction Instructions
|
|
// The Dest and VS1 read only element 0 for the vector register. Return
|
|
// 2*EEW for these. VS2 has EEW=SEW and EMUL=LMUL.
|
|
case RISCV::VWREDSUM_VS:
|
|
case RISCV::VWREDSUMU_VS:
|
|
// Vector Widening Floating-Point Reduction Instructions
|
|
case RISCV::VFWREDOSUM_VS:
|
|
case RISCV::VFWREDUSUM_VS: {
|
|
bool TwoTimes = IsMODef || MO.getOperandNo() == 3;
|
|
return TwoTimes ? MILog2SEW + 1 : MILog2SEW;
|
|
}
|
|
|
|
default:
|
|
return std::nullopt;
|
|
}
|
|
}
|
|
|
|
static std::optional<OperandInfo>
|
|
getOperandInfo(const MachineOperand &MO, const MachineRegisterInfo *MRI) {
|
|
const MachineInstr &MI = *MO.getParent();
|
|
const RISCVVPseudosTable::PseudoInfo *RVV =
|
|
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
|
|
assert(RVV && "Could not find MI in PseudoTable");
|
|
|
|
std::optional<unsigned> Log2EEW = getOperandLog2EEW(MO, MRI);
|
|
if (!Log2EEW)
|
|
return std::nullopt;
|
|
|
|
switch (RVV->BaseInstr) {
|
|
// Vector Reduction Operations
|
|
// Vector Single-Width Integer Reduction Instructions
|
|
// Vector Widening Integer Reduction Instructions
|
|
// Vector Widening Floating-Point Reduction Instructions
|
|
// The Dest and VS1 only read element 0 of the vector register. Return just
|
|
// the EEW for these.
|
|
case RISCV::VREDAND_VS:
|
|
case RISCV::VREDMAX_VS:
|
|
case RISCV::VREDMAXU_VS:
|
|
case RISCV::VREDMIN_VS:
|
|
case RISCV::VREDMINU_VS:
|
|
case RISCV::VREDOR_VS:
|
|
case RISCV::VREDSUM_VS:
|
|
case RISCV::VREDXOR_VS:
|
|
case RISCV::VWREDSUM_VS:
|
|
case RISCV::VWREDSUMU_VS:
|
|
case RISCV::VFWREDOSUM_VS:
|
|
case RISCV::VFWREDUSUM_VS:
|
|
if (MO.getOperandNo() != 2)
|
|
return OperandInfo(*Log2EEW);
|
|
break;
|
|
};
|
|
|
|
// All others have EMUL=EEW/SEW*LMUL
|
|
return OperandInfo(getEMULEqualsEEWDivSEWTimesLMUL(*Log2EEW, MI), *Log2EEW);
|
|
}
|
|
|
|
/// Return true if this optimization should consider MI for VL reduction. This
|
|
/// white-list approach simplifies this optimization for instructions that may
|
|
/// have more complex semantics with relation to how it uses VL.
|
|
static bool isSupportedInstr(const MachineInstr &MI) {
|
|
const RISCVVPseudosTable::PseudoInfo *RVV =
|
|
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
|
|
|
|
if (!RVV)
|
|
return false;
|
|
|
|
switch (RVV->BaseInstr) {
|
|
// Vector Unit-Stride Instructions
|
|
// Vector Strided Instructions
|
|
case RISCV::VLM_V:
|
|
case RISCV::VLE8_V:
|
|
case RISCV::VLSE8_V:
|
|
case RISCV::VLE16_V:
|
|
case RISCV::VLSE16_V:
|
|
case RISCV::VLE32_V:
|
|
case RISCV::VLSE32_V:
|
|
case RISCV::VLE64_V:
|
|
case RISCV::VLSE64_V:
|
|
// Vector Indexed Instructions
|
|
case RISCV::VLUXEI8_V:
|
|
case RISCV::VLOXEI8_V:
|
|
case RISCV::VLUXEI16_V:
|
|
case RISCV::VLOXEI16_V:
|
|
case RISCV::VLUXEI32_V:
|
|
case RISCV::VLOXEI32_V:
|
|
case RISCV::VLUXEI64_V:
|
|
case RISCV::VLOXEI64_V: {
|
|
for (const MachineMemOperand *MMO : MI.memoperands())
|
|
if (MMO->isVolatile())
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
// Vector Single-Width Integer Add and Subtract
|
|
case RISCV::VADD_VI:
|
|
case RISCV::VADD_VV:
|
|
case RISCV::VADD_VX:
|
|
case RISCV::VSUB_VV:
|
|
case RISCV::VSUB_VX:
|
|
case RISCV::VRSUB_VI:
|
|
case RISCV::VRSUB_VX:
|
|
// Vector Bitwise Logical Instructions
|
|
// Vector Single-Width Shift Instructions
|
|
case RISCV::VAND_VI:
|
|
case RISCV::VAND_VV:
|
|
case RISCV::VAND_VX:
|
|
case RISCV::VOR_VI:
|
|
case RISCV::VOR_VV:
|
|
case RISCV::VOR_VX:
|
|
case RISCV::VXOR_VI:
|
|
case RISCV::VXOR_VV:
|
|
case RISCV::VXOR_VX:
|
|
case RISCV::VSLL_VI:
|
|
case RISCV::VSLL_VV:
|
|
case RISCV::VSLL_VX:
|
|
case RISCV::VSRL_VI:
|
|
case RISCV::VSRL_VV:
|
|
case RISCV::VSRL_VX:
|
|
case RISCV::VSRA_VI:
|
|
case RISCV::VSRA_VV:
|
|
case RISCV::VSRA_VX:
|
|
// Vector Widening Integer Add/Subtract
|
|
case RISCV::VWADDU_VV:
|
|
case RISCV::VWADDU_VX:
|
|
case RISCV::VWSUBU_VV:
|
|
case RISCV::VWSUBU_VX:
|
|
case RISCV::VWADD_VV:
|
|
case RISCV::VWADD_VX:
|
|
case RISCV::VWSUB_VV:
|
|
case RISCV::VWSUB_VX:
|
|
case RISCV::VWADDU_WV:
|
|
case RISCV::VWADDU_WX:
|
|
case RISCV::VWSUBU_WV:
|
|
case RISCV::VWSUBU_WX:
|
|
case RISCV::VWADD_WV:
|
|
case RISCV::VWADD_WX:
|
|
case RISCV::VWSUB_WV:
|
|
case RISCV::VWSUB_WX:
|
|
// Vector Integer Extension
|
|
case RISCV::VZEXT_VF2:
|
|
case RISCV::VSEXT_VF2:
|
|
case RISCV::VZEXT_VF4:
|
|
case RISCV::VSEXT_VF4:
|
|
case RISCV::VZEXT_VF8:
|
|
case RISCV::VSEXT_VF8:
|
|
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
|
|
// FIXME: Add support
|
|
case RISCV::VMADC_VV:
|
|
case RISCV::VMADC_VI:
|
|
case RISCV::VMADC_VX:
|
|
case RISCV::VMSBC_VV:
|
|
case RISCV::VMSBC_VX:
|
|
// Vector Narrowing Integer Right Shift Instructions
|
|
case RISCV::VNSRL_WX:
|
|
case RISCV::VNSRL_WI:
|
|
case RISCV::VNSRL_WV:
|
|
case RISCV::VNSRA_WI:
|
|
case RISCV::VNSRA_WV:
|
|
case RISCV::VNSRA_WX:
|
|
// Vector Integer Compare Instructions
|
|
case RISCV::VMSEQ_VI:
|
|
case RISCV::VMSEQ_VV:
|
|
case RISCV::VMSEQ_VX:
|
|
case RISCV::VMSNE_VI:
|
|
case RISCV::VMSNE_VV:
|
|
case RISCV::VMSNE_VX:
|
|
case RISCV::VMSLTU_VV:
|
|
case RISCV::VMSLTU_VX:
|
|
case RISCV::VMSLT_VV:
|
|
case RISCV::VMSLT_VX:
|
|
case RISCV::VMSLEU_VV:
|
|
case RISCV::VMSLEU_VI:
|
|
case RISCV::VMSLEU_VX:
|
|
case RISCV::VMSLE_VV:
|
|
case RISCV::VMSLE_VI:
|
|
case RISCV::VMSLE_VX:
|
|
case RISCV::VMSGTU_VI:
|
|
case RISCV::VMSGTU_VX:
|
|
case RISCV::VMSGT_VI:
|
|
case RISCV::VMSGT_VX:
|
|
// Vector Integer Min/Max Instructions
|
|
case RISCV::VMINU_VV:
|
|
case RISCV::VMINU_VX:
|
|
case RISCV::VMIN_VV:
|
|
case RISCV::VMIN_VX:
|
|
case RISCV::VMAXU_VV:
|
|
case RISCV::VMAXU_VX:
|
|
case RISCV::VMAX_VV:
|
|
case RISCV::VMAX_VX:
|
|
// Vector Single-Width Integer Multiply Instructions
|
|
case RISCV::VMUL_VV:
|
|
case RISCV::VMUL_VX:
|
|
case RISCV::VMULH_VV:
|
|
case RISCV::VMULH_VX:
|
|
case RISCV::VMULHU_VV:
|
|
case RISCV::VMULHU_VX:
|
|
case RISCV::VMULHSU_VV:
|
|
case RISCV::VMULHSU_VX:
|
|
// Vector Integer Divide Instructions
|
|
case RISCV::VDIVU_VV:
|
|
case RISCV::VDIVU_VX:
|
|
case RISCV::VDIV_VV:
|
|
case RISCV::VDIV_VX:
|
|
case RISCV::VREMU_VV:
|
|
case RISCV::VREMU_VX:
|
|
case RISCV::VREM_VV:
|
|
case RISCV::VREM_VX:
|
|
// Vector Widening Integer Multiply Instructions
|
|
case RISCV::VWMUL_VV:
|
|
case RISCV::VWMUL_VX:
|
|
case RISCV::VWMULSU_VV:
|
|
case RISCV::VWMULSU_VX:
|
|
case RISCV::VWMULU_VV:
|
|
case RISCV::VWMULU_VX:
|
|
// Vector Single-Width Integer Multiply-Add Instructions
|
|
case RISCV::VMACC_VV:
|
|
case RISCV::VMACC_VX:
|
|
case RISCV::VNMSAC_VV:
|
|
case RISCV::VNMSAC_VX:
|
|
case RISCV::VMADD_VV:
|
|
case RISCV::VMADD_VX:
|
|
case RISCV::VNMSUB_VV:
|
|
case RISCV::VNMSUB_VX:
|
|
// Vector Integer Merge Instructions
|
|
case RISCV::VMERGE_VIM:
|
|
case RISCV::VMERGE_VVM:
|
|
case RISCV::VMERGE_VXM:
|
|
// Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
|
|
case RISCV::VADC_VIM:
|
|
case RISCV::VADC_VVM:
|
|
case RISCV::VADC_VXM:
|
|
// Vector Widening Integer Multiply-Add Instructions
|
|
case RISCV::VWMACCU_VV:
|
|
case RISCV::VWMACCU_VX:
|
|
case RISCV::VWMACC_VV:
|
|
case RISCV::VWMACC_VX:
|
|
case RISCV::VWMACCSU_VV:
|
|
case RISCV::VWMACCSU_VX:
|
|
case RISCV::VWMACCUS_VX:
|
|
// Vector Integer Merge Instructions
|
|
// FIXME: Add support
|
|
// Vector Integer Move Instructions
|
|
// FIXME: Add support
|
|
case RISCV::VMV_V_I:
|
|
case RISCV::VMV_V_X:
|
|
case RISCV::VMV_V_V:
|
|
// Vector Single-Width Saturating Add and Subtract
|
|
case RISCV::VSADDU_VV:
|
|
case RISCV::VSADDU_VX:
|
|
case RISCV::VSADDU_VI:
|
|
case RISCV::VSADD_VV:
|
|
case RISCV::VSADD_VX:
|
|
case RISCV::VSADD_VI:
|
|
case RISCV::VSSUBU_VV:
|
|
case RISCV::VSSUBU_VX:
|
|
case RISCV::VSSUB_VV:
|
|
case RISCV::VSSUB_VX:
|
|
// Vector Single-Width Averaging Add and Subtract
|
|
case RISCV::VAADDU_VV:
|
|
case RISCV::VAADDU_VX:
|
|
case RISCV::VAADD_VV:
|
|
case RISCV::VAADD_VX:
|
|
case RISCV::VASUBU_VV:
|
|
case RISCV::VASUBU_VX:
|
|
case RISCV::VASUB_VV:
|
|
case RISCV::VASUB_VX:
|
|
// Vector Single-Width Fractional Multiply with Rounding and Saturation
|
|
case RISCV::VSMUL_VV:
|
|
case RISCV::VSMUL_VX:
|
|
// Vector Single-Width Scaling Shift Instructions
|
|
case RISCV::VSSRL_VV:
|
|
case RISCV::VSSRL_VX:
|
|
case RISCV::VSSRL_VI:
|
|
case RISCV::VSSRA_VV:
|
|
case RISCV::VSSRA_VX:
|
|
case RISCV::VSSRA_VI:
|
|
// Vector Narrowing Fixed-Point Clip Instructions
|
|
case RISCV::VNCLIPU_WV:
|
|
case RISCV::VNCLIPU_WX:
|
|
case RISCV::VNCLIPU_WI:
|
|
case RISCV::VNCLIP_WV:
|
|
case RISCV::VNCLIP_WX:
|
|
case RISCV::VNCLIP_WI:
|
|
|
|
// Vector Crypto
|
|
case RISCV::VWSLL_VI:
|
|
|
|
// Vector Mask Instructions
|
|
// Vector Mask-Register Logical Instructions
|
|
// vmsbf.m set-before-first mask bit
|
|
// vmsif.m set-including-first mask bit
|
|
// vmsof.m set-only-first mask bit
|
|
// Vector Iota Instruction
|
|
// Vector Element Index Instruction
|
|
case RISCV::VMAND_MM:
|
|
case RISCV::VMNAND_MM:
|
|
case RISCV::VMANDN_MM:
|
|
case RISCV::VMXOR_MM:
|
|
case RISCV::VMOR_MM:
|
|
case RISCV::VMNOR_MM:
|
|
case RISCV::VMORN_MM:
|
|
case RISCV::VMXNOR_MM:
|
|
case RISCV::VMSBF_M:
|
|
case RISCV::VMSIF_M:
|
|
case RISCV::VMSOF_M:
|
|
case RISCV::VIOTA_M:
|
|
case RISCV::VID_V:
|
|
// Vector Single-Width Floating-Point Add/Subtract Instructions
|
|
case RISCV::VFADD_VF:
|
|
case RISCV::VFADD_VV:
|
|
case RISCV::VFSUB_VF:
|
|
case RISCV::VFSUB_VV:
|
|
case RISCV::VFRSUB_VF:
|
|
// Vector Widening Floating-Point Add/Subtract Instructions
|
|
case RISCV::VFWADD_VV:
|
|
case RISCV::VFWADD_VF:
|
|
case RISCV::VFWSUB_VV:
|
|
case RISCV::VFWSUB_VF:
|
|
case RISCV::VFWADD_WF:
|
|
case RISCV::VFWADD_WV:
|
|
case RISCV::VFWSUB_WF:
|
|
case RISCV::VFWSUB_WV:
|
|
// Vector Single-Width Floating-Point Multiply/Divide Instructions
|
|
case RISCV::VFMUL_VF:
|
|
case RISCV::VFMUL_VV:
|
|
case RISCV::VFDIV_VF:
|
|
case RISCV::VFDIV_VV:
|
|
case RISCV::VFRDIV_VF:
|
|
// Vector Widening Floating-Point Multiply
|
|
case RISCV::VFWMUL_VF:
|
|
case RISCV::VFWMUL_VV:
|
|
// Vector Single-Width Floating-Point Fused Multiply-Add Instructions
|
|
case RISCV::VFMACC_VV:
|
|
case RISCV::VFMACC_VF:
|
|
case RISCV::VFNMACC_VV:
|
|
case RISCV::VFNMACC_VF:
|
|
case RISCV::VFMSAC_VV:
|
|
case RISCV::VFMSAC_VF:
|
|
case RISCV::VFNMSAC_VV:
|
|
case RISCV::VFNMSAC_VF:
|
|
case RISCV::VFMADD_VV:
|
|
case RISCV::VFMADD_VF:
|
|
case RISCV::VFNMADD_VV:
|
|
case RISCV::VFNMADD_VF:
|
|
case RISCV::VFMSUB_VV:
|
|
case RISCV::VFMSUB_VF:
|
|
case RISCV::VFNMSUB_VV:
|
|
case RISCV::VFNMSUB_VF:
|
|
// Vector Widening Floating-Point Fused Multiply-Add Instructions
|
|
case RISCV::VFWMACC_VV:
|
|
case RISCV::VFWMACC_VF:
|
|
case RISCV::VFWNMACC_VV:
|
|
case RISCV::VFWNMACC_VF:
|
|
case RISCV::VFWMSAC_VV:
|
|
case RISCV::VFWMSAC_VF:
|
|
case RISCV::VFWNMSAC_VV:
|
|
case RISCV::VFWNMSAC_VF:
|
|
case RISCV::VFWMACCBF16_VV:
|
|
case RISCV::VFWMACCBF16_VF:
|
|
// Vector Floating-Point Square-Root Instruction
|
|
case RISCV::VFSQRT_V:
|
|
// Vector Floating-Point Reciprocal Square-Root Estimate Instruction
|
|
case RISCV::VFRSQRT7_V:
|
|
// Vector Floating-Point MIN/MAX Instructions
|
|
case RISCV::VFMIN_VF:
|
|
case RISCV::VFMIN_VV:
|
|
case RISCV::VFMAX_VF:
|
|
case RISCV::VFMAX_VV:
|
|
// Vector Floating-Point Sign-Injection Instructions
|
|
case RISCV::VFSGNJ_VF:
|
|
case RISCV::VFSGNJ_VV:
|
|
case RISCV::VFSGNJN_VV:
|
|
case RISCV::VFSGNJN_VF:
|
|
case RISCV::VFSGNJX_VF:
|
|
case RISCV::VFSGNJX_VV:
|
|
// Vector Floating-Point Compare Instructions
|
|
case RISCV::VMFEQ_VF:
|
|
case RISCV::VMFEQ_VV:
|
|
case RISCV::VMFNE_VF:
|
|
case RISCV::VMFNE_VV:
|
|
case RISCV::VMFLT_VF:
|
|
case RISCV::VMFLT_VV:
|
|
case RISCV::VMFLE_VF:
|
|
case RISCV::VMFLE_VV:
|
|
case RISCV::VMFGT_VF:
|
|
case RISCV::VMFGE_VF:
|
|
// Single-Width Floating-Point/Integer Type-Convert Instructions
|
|
case RISCV::VFCVT_XU_F_V:
|
|
case RISCV::VFCVT_X_F_V:
|
|
case RISCV::VFCVT_RTZ_XU_F_V:
|
|
case RISCV::VFCVT_RTZ_X_F_V:
|
|
case RISCV::VFCVT_F_XU_V:
|
|
case RISCV::VFCVT_F_X_V:
|
|
// Widening Floating-Point/Integer Type-Convert Instructions
|
|
case RISCV::VFWCVT_XU_F_V:
|
|
case RISCV::VFWCVT_X_F_V:
|
|
case RISCV::VFWCVT_RTZ_XU_F_V:
|
|
case RISCV::VFWCVT_RTZ_X_F_V:
|
|
case RISCV::VFWCVT_F_XU_V:
|
|
case RISCV::VFWCVT_F_X_V:
|
|
case RISCV::VFWCVT_F_F_V:
|
|
case RISCV::VFWCVTBF16_F_F_V:
|
|
// Narrowing Floating-Point/Integer Type-Convert Instructions
|
|
case RISCV::VFNCVT_XU_F_W:
|
|
case RISCV::VFNCVT_X_F_W:
|
|
case RISCV::VFNCVT_RTZ_XU_F_W:
|
|
case RISCV::VFNCVT_RTZ_X_F_W:
|
|
case RISCV::VFNCVT_F_XU_W:
|
|
case RISCV::VFNCVT_F_X_W:
|
|
case RISCV::VFNCVT_F_F_W:
|
|
case RISCV::VFNCVT_ROD_F_F_W:
|
|
case RISCV::VFNCVTBF16_F_F_W:
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
/// Return true if MO is a vector operand but is used as a scalar operand.
|
|
static bool isVectorOpUsedAsScalarOp(const MachineOperand &MO) {
|
|
const MachineInstr *MI = MO.getParent();
|
|
const RISCVVPseudosTable::PseudoInfo *RVV =
|
|
RISCVVPseudosTable::getPseudoInfo(MI->getOpcode());
|
|
|
|
if (!RVV)
|
|
return false;
|
|
|
|
switch (RVV->BaseInstr) {
|
|
// Reductions only use vs1[0] of vs1
|
|
case RISCV::VREDAND_VS:
|
|
case RISCV::VREDMAX_VS:
|
|
case RISCV::VREDMAXU_VS:
|
|
case RISCV::VREDMIN_VS:
|
|
case RISCV::VREDMINU_VS:
|
|
case RISCV::VREDOR_VS:
|
|
case RISCV::VREDSUM_VS:
|
|
case RISCV::VREDXOR_VS:
|
|
case RISCV::VWREDSUM_VS:
|
|
case RISCV::VWREDSUMU_VS:
|
|
case RISCV::VFREDMAX_VS:
|
|
case RISCV::VFREDMIN_VS:
|
|
case RISCV::VFREDOSUM_VS:
|
|
case RISCV::VFREDUSUM_VS:
|
|
case RISCV::VFWREDOSUM_VS:
|
|
case RISCV::VFWREDUSUM_VS:
|
|
return MO.getOperandNo() == 3;
|
|
case RISCV::VMV_X_S:
|
|
case RISCV::VFMV_F_S:
|
|
return MO.getOperandNo() == 1;
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
/// Return true if MI may read elements past VL.
|
|
static bool mayReadPastVL(const MachineInstr &MI) {
|
|
const RISCVVPseudosTable::PseudoInfo *RVV =
|
|
RISCVVPseudosTable::getPseudoInfo(MI.getOpcode());
|
|
if (!RVV)
|
|
return true;
|
|
|
|
switch (RVV->BaseInstr) {
|
|
// vslidedown instructions may read elements past VL. They are handled
|
|
// according to current tail policy.
|
|
case RISCV::VSLIDEDOWN_VI:
|
|
case RISCV::VSLIDEDOWN_VX:
|
|
case RISCV::VSLIDE1DOWN_VX:
|
|
case RISCV::VFSLIDE1DOWN_VF:
|
|
|
|
// vrgather instructions may read the source vector at any index < VLMAX,
|
|
// regardless of VL.
|
|
case RISCV::VRGATHER_VI:
|
|
case RISCV::VRGATHER_VV:
|
|
case RISCV::VRGATHER_VX:
|
|
case RISCV::VRGATHEREI16_VV:
|
|
return true;
|
|
|
|
default:
|
|
return false;
|
|
}
|
|
}
|
|
|
|
bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
|
|
const MCInstrDesc &Desc = MI.getDesc();
|
|
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags))
|
|
return false;
|
|
|
|
if (MI.getNumExplicitDefs() != 1)
|
|
return false;
|
|
|
|
// Some instructions have implicit defs e.g. $vxsat. If they might be read
|
|
// later then we can't reduce VL.
|
|
if (!MI.allImplicitDefsAreDead()) {
|
|
LLVM_DEBUG(dbgs() << "Not a candidate because has non-dead implicit def\n");
|
|
return false;
|
|
}
|
|
|
|
if (MI.mayRaiseFPException()) {
|
|
LLVM_DEBUG(dbgs() << "Not a candidate because may raise FP exception\n");
|
|
return false;
|
|
}
|
|
|
|
// Some instructions that produce vectors have semantics that make it more
|
|
// difficult to determine whether the VL can be reduced. For example, some
|
|
// instructions, such as reductions, may write lanes past VL to a scalar
|
|
// register. Other instructions, such as some loads or stores, may write
|
|
// lower lanes using data from higher lanes. There may be other complex
|
|
// semantics not mentioned here that make it hard to determine whether
|
|
// the VL can be optimized. As a result, a white-list of supported
|
|
// instructions is used. Over time, more instructions can be supported
|
|
// upon careful examination of their semantics under the logic in this
|
|
// optimization.
|
|
// TODO: Use a better approach than a white-list, such as adding
|
|
// properties to instructions using something like TSFlags.
|
|
if (!isSupportedInstr(MI)) {
|
|
LLVM_DEBUG(dbgs() << "Not a candidate due to unsupported instruction\n");
|
|
return false;
|
|
}
|
|
|
|
assert(MI.getOperand(0).isReg() &&
|
|
isVectorRegClass(MI.getOperand(0).getReg(), MRI) &&
|
|
"All supported instructions produce a vector register result");
|
|
|
|
LLVM_DEBUG(dbgs() << "Found a candidate for VL reduction: " << MI << "\n");
|
|
return true;
|
|
}
|
|
|
|
std::optional<MachineOperand>
|
|
RISCVVLOptimizer::getMinimumVLForUser(const MachineOperand &UserOp) const {
|
|
const MachineInstr &UserMI = *UserOp.getParent();
|
|
const MCInstrDesc &Desc = UserMI.getDesc();
|
|
|
|
if (!RISCVII::hasVLOp(Desc.TSFlags) || !RISCVII::hasSEWOp(Desc.TSFlags)) {
|
|
LLVM_DEBUG(dbgs() << " Abort due to lack of VL, assume that"
|
|
" use VLMAX\n");
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (mayReadPastVL(UserMI)) {
|
|
LLVM_DEBUG(dbgs() << " Abort because used by unsafe instruction\n");
|
|
return std::nullopt;
|
|
}
|
|
|
|
unsigned VLOpNum = RISCVII::getVLOpNum(Desc);
|
|
const MachineOperand &VLOp = UserMI.getOperand(VLOpNum);
|
|
// Looking for an immediate or a register VL that isn't X0.
|
|
assert((!VLOp.isReg() || VLOp.getReg() != RISCV::X0) &&
|
|
"Did not expect X0 VL");
|
|
|
|
// If the user is a passthru it will read the elements past VL, so
|
|
// abort if any of the elements past VL are demanded.
|
|
if (UserOp.isTied()) {
|
|
assert(UserOp.getOperandNo() == UserMI.getNumExplicitDefs() &&
|
|
RISCVII::isFirstDefTiedToFirstUse(UserMI.getDesc()));
|
|
auto DemandedVL = DemandedVLs.lookup(&UserMI);
|
|
if (!DemandedVL || !RISCV::isVLKnownLE(*DemandedVL, VLOp)) {
|
|
LLVM_DEBUG(dbgs() << " Abort because user is passthru in "
|
|
"instruction with demanded tail\n");
|
|
return std::nullopt;
|
|
}
|
|
}
|
|
|
|
// Instructions like reductions may use a vector register as a scalar
|
|
// register. In this case, we should treat it as only reading the first lane.
|
|
if (isVectorOpUsedAsScalarOp(UserOp)) {
|
|
[[maybe_unused]] Register R = UserOp.getReg();
|
|
[[maybe_unused]] const TargetRegisterClass *RC = MRI->getRegClass(R);
|
|
assert(RISCV::VRRegClass.hasSubClassEq(RC) &&
|
|
"Expect LMUL 1 register class for vector as scalar operands!");
|
|
LLVM_DEBUG(dbgs() << " Used this operand as a scalar operand\n");
|
|
|
|
return MachineOperand::CreateImm(1);
|
|
}
|
|
|
|
// If we know the demanded VL of UserMI, then we can reduce the VL it
|
|
// requires.
|
|
if (auto DemandedVL = DemandedVLs.lookup(&UserMI)) {
|
|
assert(isCandidate(UserMI));
|
|
if (RISCV::isVLKnownLE(*DemandedVL, VLOp))
|
|
return DemandedVL;
|
|
}
|
|
|
|
return VLOp;
|
|
}
|
|
|
|
std::optional<MachineOperand>
|
|
RISCVVLOptimizer::checkUsers(const MachineInstr &MI) const {
|
|
std::optional<MachineOperand> CommonVL;
|
|
SmallSetVector<MachineOperand *, 8> Worklist;
|
|
SmallPtrSet<const MachineInstr *, 4> PHISeen;
|
|
for (auto &UserOp : MRI->use_operands(MI.getOperand(0).getReg()))
|
|
Worklist.insert(&UserOp);
|
|
|
|
while (!Worklist.empty()) {
|
|
MachineOperand &UserOp = *Worklist.pop_back_val();
|
|
const MachineInstr &UserMI = *UserOp.getParent();
|
|
LLVM_DEBUG(dbgs() << " Checking user: " << UserMI << "\n");
|
|
|
|
if (UserMI.isCopy() && UserMI.getOperand(0).getReg().isVirtual() &&
|
|
UserMI.getOperand(0).getSubReg() == RISCV::NoSubRegister &&
|
|
UserMI.getOperand(1).getSubReg() == RISCV::NoSubRegister) {
|
|
LLVM_DEBUG(dbgs() << " Peeking through uses of COPY\n");
|
|
for (auto &CopyUse : MRI->use_operands(UserMI.getOperand(0).getReg()))
|
|
Worklist.insert(&CopyUse);
|
|
continue;
|
|
}
|
|
|
|
if (UserMI.isPHI()) {
|
|
// Don't follow PHI cycles
|
|
if (!PHISeen.insert(&UserMI).second)
|
|
continue;
|
|
LLVM_DEBUG(dbgs() << " Peeking through uses of PHI\n");
|
|
for (auto &PhiUse : MRI->use_operands(UserMI.getOperand(0).getReg()))
|
|
Worklist.insert(&PhiUse);
|
|
continue;
|
|
}
|
|
|
|
auto VLOp = getMinimumVLForUser(UserOp);
|
|
if (!VLOp)
|
|
return std::nullopt;
|
|
|
|
// Use the largest VL among all the users. If we cannot determine this
|
|
// statically, then we cannot optimize the VL.
|
|
if (!CommonVL || RISCV::isVLKnownLE(*CommonVL, *VLOp)) {
|
|
CommonVL = *VLOp;
|
|
LLVM_DEBUG(dbgs() << " User VL is: " << VLOp << "\n");
|
|
} else if (!RISCV::isVLKnownLE(*VLOp, *CommonVL)) {
|
|
LLVM_DEBUG(dbgs() << " Abort because cannot determine a common VL\n");
|
|
return std::nullopt;
|
|
}
|
|
|
|
if (!RISCVII::hasSEWOp(UserMI.getDesc().TSFlags)) {
|
|
LLVM_DEBUG(dbgs() << " Abort due to lack of SEW operand\n");
|
|
return std::nullopt;
|
|
}
|
|
|
|
std::optional<OperandInfo> ConsumerInfo = getOperandInfo(UserOp, MRI);
|
|
std::optional<OperandInfo> ProducerInfo =
|
|
getOperandInfo(MI.getOperand(0), MRI);
|
|
if (!ConsumerInfo || !ProducerInfo) {
|
|
LLVM_DEBUG(dbgs() << " Abort due to unknown operand information.\n");
|
|
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
|
|
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
|
|
return std::nullopt;
|
|
}
|
|
|
|
// If the operand is used as a scalar operand, then the EEW must be
|
|
// compatible. Otherwise, the EMUL *and* EEW must be compatible.
|
|
bool IsVectorOpUsedAsScalarOp = isVectorOpUsedAsScalarOp(UserOp);
|
|
if ((IsVectorOpUsedAsScalarOp &&
|
|
!OperandInfo::EEWAreEqual(*ConsumerInfo, *ProducerInfo)) ||
|
|
(!IsVectorOpUsedAsScalarOp &&
|
|
!OperandInfo::EMULAndEEWAreEqual(*ConsumerInfo, *ProducerInfo))) {
|
|
LLVM_DEBUG(
|
|
dbgs()
|
|
<< " Abort due to incompatible information for EMUL or EEW.\n");
|
|
LLVM_DEBUG(dbgs() << " ConsumerInfo is: " << ConsumerInfo << "\n");
|
|
LLVM_DEBUG(dbgs() << " ProducerInfo is: " << ProducerInfo << "\n");
|
|
return std::nullopt;
|
|
}
|
|
}
|
|
|
|
return CommonVL;
|
|
}
|
|
|
|
/// Try to replace the VL operand of \p MI with the VL recorded for it in
/// DemandedVLs.
///
/// \returns true if the VL operand was rewritten. Nothing is changed when the
/// VL is already the immediate 1, no demanded VL is recorded, the demanded VL
/// is not known to be <= the current VL or is identical to it, or the
/// demanded VL is a register whose definition does not dominate \p MI.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) const {
  LLVM_DEBUG(dbgs() << "Trying to reduce VL for " << MI << "\n");

  MachineOperand &VLOp = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));

  // If the VL is 1, then there is no need to reduce it. This is an
  // optimization, not needed to preserve correctness.
  const bool VLIsOne = VLOp.isImm() && VLOp.getImm() == 1;
  if (VLIsOne) {
    LLVM_DEBUG(dbgs() << " Abort due to VL == 1, no point in reducing.\n");
    return false;
  }

  auto CommonVL = DemandedVLs.lookup(&MI);
  if (!CommonVL)
    return false;

  assert((CommonVL->isImm() || CommonVL->getReg().isVirtual()) &&
         "Expected VL to be an Imm or virtual Reg");

  if (!RISCV::isVLKnownLE(*CommonVL, VLOp)) {
    LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n");
    return false;
  }

  if (CommonVL->isIdenticalTo(VLOp)) {
    LLVM_DEBUG(
        dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n");
    return false;
  }

  if (!CommonVL->isImm()) {
    // A register VL can only be used if its definition dominates MI.
    const Register NewVLReg = CommonVL->getReg();
    const MachineInstr *VLDef = MRI->getVRegDef(NewVLReg);
    if (!MDT->dominates(VLDef, &MI))
      return false;
    LLVM_DEBUG(
        dbgs() << " Reduce VL from " << VLOp << " to "
               << printReg(CommonVL->getReg(), MRI->getTargetRegisterInfo())
               << " for " << MI << "\n");

    // All our checks passed. We can reduce VL.
    VLOp.ChangeToRegister(NewVLReg, false);
    return true;
  }

  // An immediate VL can be substituted directly.
  LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                    << CommonVL->getImm() << " for " << MI << "\n");
  VLOp.ChangeToImmediate(CommonVL->getImm());
  return true;
}
|
|
|
|
/// Pass entry point: record the VL demanded of every candidate instruction in
/// \p MF, then shrink VL operands to what is demanded.
///
/// \returns true if the VL operand of at least one instruction was changed.
/// The function is left untouched when skipFunction() applies or the
/// subtarget has no vector instructions.
bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
  assert(DemandedVLs.empty());
  if (skipFunction(MF.getFunction()))
    return false;

  MRI = &MF.getRegInfo();
  MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();

  const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
  if (!Subtarget.hasVInstructions())
    return false;

  // Phase 1: for each instruction that defines a vector, compute what VL its
  // downstream users demand. Blocks are walked in post order and each block
  // bottom-up.
  for (MachineBasicBlock *Block : post_order(&MF)) {
    assert(MDT->isReachableFromEntry(Block));
    for (MachineInstr &Inst : reverse(*Block)) {
      if (isCandidate(Inst))
        DemandedVLs.insert({&Inst, checkUsers(Inst)});
    }
  }

  // Phase 2: go through and see if we can reduce the VL of any instructions
  // to only what's demanded.
  bool Changed = false;
  for (MachineBasicBlock &Block : MF) {
    // Avoid unreachable blocks as they have degenerate dominance
    if (!MDT->isReachableFromEntry(&Block))
      continue;

    for (MachineInstr &Inst : reverse(Block)) {
      if (isCandidate(Inst) && tryReduceVL(Inst))
        Changed = true;
    }
  }

  // Drop the per-function state before returning.
  DemandedVLs.clear();
  return Changed;
}
|