[TTI] Plumb CostKind through getPartialReductionCost (#144953)

Purely for the sake of being idiomatic with other TTI costing routines,
no direct motivation beyond that.
This commit is contained in:
Philip Reames 2025-06-19 15:29:56 -07:00 committed by GitHub
parent dfb5cadf5e
commit b96370131d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 40 additions and 37 deletions

View File

@ -1332,8 +1332,8 @@ public:
LLVM_ABI InstructionCost getPartialReductionCost( LLVM_ABI InstructionCost getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend, ElementCount VF, PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend, PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
std::optional<unsigned> BinOp = std::nullopt) const; TTI::TargetCostKind CostKind) const;
/// \return The maximum interleave factor that any transform should try to /// \return The maximum interleave factor that any transform should try to
/// perform for this target. This number depends on the level of parallelism /// perform for this target. This number depends on the level of parallelism

View File

@ -652,12 +652,11 @@ public:
virtual bool enableWritePrefetching() const { return false; } virtual bool enableWritePrefetching() const { return false; }
virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; } virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
virtual InstructionCost virtual InstructionCost getPartialReductionCost(
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
Type *AccumType, ElementCount VF, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpAExtend, TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
TTI::PartialReductionExtendKind OpBExtend, TTI::TargetCostKind CostKind) const {
std::optional<unsigned> BinOp = std::nullopt) const {
return InstructionCost::getInvalid(); return InstructionCost::getInvalid();
} }

View File

@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
InstructionCost TargetTransformInfo::getPartialReductionCost( InstructionCost TargetTransformInfo::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, PartialReductionExtendKind OpAExtend, ElementCount VF, PartialReductionExtendKind OpAExtend,
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const { PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
TTI::TargetCostKind CostKind) const {
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB, return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
AccumType, VF, OpAExtend, OpBExtend, AccumType, VF, OpAExtend, OpBExtend,
BinOp); BinOp, CostKind);
} }
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const { unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {

View File

@ -5395,11 +5395,14 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
InstructionCost AArch64TTIImpl::getPartialReductionCost( InstructionCost AArch64TTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend, TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
std::optional<unsigned> BinOp) const { TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid(); InstructionCost Invalid = InstructionCost::getInvalid();
InstructionCost Cost(TTI::TCC_Basic); InstructionCost Cost(TTI::TCC_Basic);
if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;
// Sub opcodes currently only occur in chained cases. // Sub opcodes currently only occur in chained cases.
// Independent partial reduction subtractions are still costed as an add // Independent partial reduction subtractions are still costed as an add
if (Opcode != Instruction::Add && Opcode != Instruction::Sub) if (Opcode != Instruction::Add && Opcode != Instruction::Sub)

View File

@ -382,12 +382,11 @@ public:
return BaseT::isLegalNTLoad(DataType, Alignment); return BaseT::isLegalNTLoad(DataType, Alignment);
} }
InstructionCost InstructionCost getPartialReductionCost(
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
Type *AccumType, ElementCount VF, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpAExtend, TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
TTI::PartialReductionExtendKind OpBExtend, TTI::TargetCostKind CostKind) const override;
std::optional<unsigned> BinOp) const override;
bool enableOrderedReductions() const override { return true; } bool enableOrderedReductions() const override { return true; }

View File

@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
InstructionCost RISCVTTIImpl::getPartialReductionCost( InstructionCost RISCVTTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend, TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
std::optional<unsigned> BinOp) const { TTI::TargetCostKind CostKind) const {
// zve32x is broken for partial_reduce_umla, but let's make sure we // zve32x is broken for partial_reduce_umla, but let's make sure we
// don't generate them. // don't generate them.
@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4)); Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4));
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp); std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
// Note: Asuming all vqdot* variants are equal cost // Note: Asuming all vqdot* variants are equal cost
// TODO: Thread CostKind through this API return LT.first *
return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind);
TTI::TCK_RecipThroughput);
} }
bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const { bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {

View File

@ -100,12 +100,11 @@ public:
TargetTransformInfo::PopcntSupportKind TargetTransformInfo::PopcntSupportKind
getPopcntSupport(unsigned TyWidth) const override; getPopcntSupport(unsigned TyWidth) const override;
InstructionCost InstructionCost getPartialReductionCost(
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
Type *AccumType, ElementCount VF, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpAExtend, TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
TTI::PartialReductionExtendKind OpBExtend, TTI::TargetCostKind CostKind) const override;
std::optional<unsigned> BinOp) const override;
bool shouldExpandReduction(const IntrinsicInst *II) const override; bool shouldExpandReduction(const IntrinsicInst *II) const override;
bool supportsScalableVectors() const override { bool supportsScalableVectors() const override {

View File

@ -198,12 +198,15 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost( InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend, TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
std::optional<unsigned> BinOp) const { TTI::TargetCostKind CostKind) const {
InstructionCost Invalid = InstructionCost::getInvalid(); InstructionCost Invalid = InstructionCost::getInvalid();
if (!VF.isFixed() || !ST->hasSIMD128()) if (!VF.isFixed() || !ST->hasSIMD128())
return Invalid; return Invalid;
if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;
InstructionCost Cost(TTI::TCC_Basic); InstructionCost Cost(TTI::TCC_Basic);
// Possible options: // Possible options:

View File

@ -86,8 +86,8 @@ public:
InstructionCost getPartialReductionCost( InstructionCost getPartialReductionCost(
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
TTI::PartialReductionExtendKind OpBExtend, TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
std::optional<unsigned> BinOp = std::nullopt) const override; TTI::TargetCostKind CostKind) const override;
TTI::ReductionShuffle TTI::ReductionShuffle
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override; getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;

View File

@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
[&](ElementCount VF) { [&](ElementCount VF) {
InstructionCost Cost = TTI->getPartialReductionCost( InstructionCost Cost = TTI->getPartialReductionCost(
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(), Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
VF, OpAExtend, OpBExtend, BinOp->getOpcode()); VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
return Cost.isValid(); return Cost.isValid();
}, },
Range)) { Range)) {

View File

@ -336,9 +336,9 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
return TargetTransformInfo::PR_None; return TargetTransformInfo::PR_None;
}; };
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB, return Ctx.TTI.getPartialReductionCost(
PhiType, VF, GetExtendKind(ExtAR), getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
GetExtendKind(ExtBR), Opcode); GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
} }
void VPPartialReductionRecipe::execute(VPTransformState &State) { void VPPartialReductionRecipe::execute(VPTransformState &State) {