[TTI] Plumb CostKind through getPartialReductionCost (#144953)
This change is purely for the sake of being idiomatic with the other TTI costing routines; there is no direct motivation beyond that.
This commit is contained in:
parent
dfb5cadf5e
commit
b96370131d
@ -1332,8 +1332,8 @@ public:
|
|||||||
LLVM_ABI InstructionCost getPartialReductionCost(
|
LLVM_ABI InstructionCost getPartialReductionCost(
|
||||||
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
ElementCount VF, PartialReductionExtendKind OpAExtend,
|
ElementCount VF, PartialReductionExtendKind OpAExtend,
|
||||||
PartialReductionExtendKind OpBExtend,
|
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
std::optional<unsigned> BinOp = std::nullopt) const;
|
TTI::TargetCostKind CostKind) const;
|
||||||
|
|
||||||
/// \return The maximum interleave factor that any transform should try to
|
/// \return The maximum interleave factor that any transform should try to
|
||||||
/// perform for this target. This number depends on the level of parallelism
|
/// perform for this target. This number depends on the level of parallelism
|
||||||
|
@ -652,12 +652,11 @@ public:
|
|||||||
virtual bool enableWritePrefetching() const { return false; }
|
virtual bool enableWritePrefetching() const { return false; }
|
||||||
virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
|
virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
|
||||||
|
|
||||||
virtual InstructionCost
|
virtual InstructionCost getPartialReductionCost(
|
||||||
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
Type *AccumType, ElementCount VF,
|
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
||||||
TTI::PartialReductionExtendKind OpAExtend,
|
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
TTI::PartialReductionExtendKind OpBExtend,
|
TTI::TargetCostKind CostKind) const {
|
||||||
std::optional<unsigned> BinOp = std::nullopt) const {
|
|
||||||
return InstructionCost::getInvalid();
|
return InstructionCost::getInvalid();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
|
|||||||
InstructionCost TargetTransformInfo::getPartialReductionCost(
|
InstructionCost TargetTransformInfo::getPartialReductionCost(
|
||||||
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
ElementCount VF, PartialReductionExtendKind OpAExtend,
|
ElementCount VF, PartialReductionExtendKind OpAExtend,
|
||||||
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
|
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
|
TTI::TargetCostKind CostKind) const {
|
||||||
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
|
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
|
||||||
AccumType, VF, OpAExtend, OpBExtend,
|
AccumType, VF, OpAExtend, OpBExtend,
|
||||||
BinOp);
|
BinOp, CostKind);
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
|
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {
|
||||||
|
@ -5395,11 +5395,14 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
|
|||||||
InstructionCost AArch64TTIImpl::getPartialReductionCost(
|
InstructionCost AArch64TTIImpl::getPartialReductionCost(
|
||||||
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
||||||
TTI::PartialReductionExtendKind OpBExtend,
|
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
std::optional<unsigned> BinOp) const {
|
TTI::TargetCostKind CostKind) const {
|
||||||
InstructionCost Invalid = InstructionCost::getInvalid();
|
InstructionCost Invalid = InstructionCost::getInvalid();
|
||||||
InstructionCost Cost(TTI::TCC_Basic);
|
InstructionCost Cost(TTI::TCC_Basic);
|
||||||
|
|
||||||
|
if (CostKind != TTI::TCK_RecipThroughput)
|
||||||
|
return Invalid;
|
||||||
|
|
||||||
// Sub opcodes currently only occur in chained cases.
|
// Sub opcodes currently only occur in chained cases.
|
||||||
// Independent partial reduction subtractions are still costed as an add
|
// Independent partial reduction subtractions are still costed as an add
|
||||||
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
|
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
|
||||||
|
@ -382,12 +382,11 @@ public:
|
|||||||
return BaseT::isLegalNTLoad(DataType, Alignment);
|
return BaseT::isLegalNTLoad(DataType, Alignment);
|
||||||
}
|
}
|
||||||
|
|
||||||
InstructionCost
|
InstructionCost getPartialReductionCost(
|
||||||
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
Type *AccumType, ElementCount VF,
|
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
||||||
TTI::PartialReductionExtendKind OpAExtend,
|
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
TTI::PartialReductionExtendKind OpBExtend,
|
TTI::TargetCostKind CostKind) const override;
|
||||||
std::optional<unsigned> BinOp) const override;
|
|
||||||
|
|
||||||
bool enableOrderedReductions() const override { return true; }
|
bool enableOrderedReductions() const override { return true; }
|
||||||
|
|
||||||
|
@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
|
|||||||
InstructionCost RISCVTTIImpl::getPartialReductionCost(
|
InstructionCost RISCVTTIImpl::getPartialReductionCost(
|
||||||
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
||||||
TTI::PartialReductionExtendKind OpBExtend,
|
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
std::optional<unsigned> BinOp) const {
|
TTI::TargetCostKind CostKind) const {
|
||||||
|
|
||||||
// zve32x is broken for partial_reduce_umla, but let's make sure we
|
// zve32x is broken for partial_reduce_umla, but let's make sure we
|
||||||
// don't generate them.
|
// don't generate them.
|
||||||
@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
|
|||||||
Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4));
|
Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4));
|
||||||
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
|
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
|
||||||
// Note: Assuming all vqdot* variants are equal cost
|
// Note: Assuming all vqdot* variants are equal cost
|
||||||
// TODO: Thread CostKind through this API
|
return LT.first *
|
||||||
return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
|
getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind);
|
||||||
TTI::TCK_RecipThroughput);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
|
bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {
|
||||||
|
@ -100,12 +100,11 @@ public:
|
|||||||
TargetTransformInfo::PopcntSupportKind
|
TargetTransformInfo::PopcntSupportKind
|
||||||
getPopcntSupport(unsigned TyWidth) const override;
|
getPopcntSupport(unsigned TyWidth) const override;
|
||||||
|
|
||||||
InstructionCost
|
InstructionCost getPartialReductionCost(
|
||||||
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
Type *AccumType, ElementCount VF,
|
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
||||||
TTI::PartialReductionExtendKind OpAExtend,
|
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
TTI::PartialReductionExtendKind OpBExtend,
|
TTI::TargetCostKind CostKind) const override;
|
||||||
std::optional<unsigned> BinOp) const override;
|
|
||||||
|
|
||||||
bool shouldExpandReduction(const IntrinsicInst *II) const override;
|
bool shouldExpandReduction(const IntrinsicInst *II) const override;
|
||||||
bool supportsScalableVectors() const override {
|
bool supportsScalableVectors() const override {
|
||||||
|
@ -198,12 +198,15 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
|
|||||||
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
|
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
|
||||||
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
||||||
TTI::PartialReductionExtendKind OpBExtend,
|
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
std::optional<unsigned> BinOp) const {
|
TTI::TargetCostKind CostKind) const {
|
||||||
InstructionCost Invalid = InstructionCost::getInvalid();
|
InstructionCost Invalid = InstructionCost::getInvalid();
|
||||||
if (!VF.isFixed() || !ST->hasSIMD128())
|
if (!VF.isFixed() || !ST->hasSIMD128())
|
||||||
return Invalid;
|
return Invalid;
|
||||||
|
|
||||||
|
if (CostKind != TTI::TCK_RecipThroughput)
|
||||||
|
return Invalid;
|
||||||
|
|
||||||
InstructionCost Cost(TTI::TCC_Basic);
|
InstructionCost Cost(TTI::TCC_Basic);
|
||||||
|
|
||||||
// Possible options:
|
// Possible options:
|
||||||
|
@ -86,8 +86,8 @@ public:
|
|||||||
InstructionCost getPartialReductionCost(
|
InstructionCost getPartialReductionCost(
|
||||||
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
|
||||||
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
|
||||||
TTI::PartialReductionExtendKind OpBExtend,
|
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
|
||||||
std::optional<unsigned> BinOp = std::nullopt) const override;
|
TTI::TargetCostKind CostKind) const override;
|
||||||
TTI::ReductionShuffle
|
TTI::ReductionShuffle
|
||||||
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
|
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
|
||||||
|
|
||||||
|
@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
|
|||||||
[&](ElementCount VF) {
|
[&](ElementCount VF) {
|
||||||
InstructionCost Cost = TTI->getPartialReductionCost(
|
InstructionCost Cost = TTI->getPartialReductionCost(
|
||||||
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
|
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
|
||||||
VF, OpAExtend, OpBExtend, BinOp->getOpcode());
|
VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
|
||||||
return Cost.isValid();
|
return Cost.isValid();
|
||||||
},
|
},
|
||||||
Range)) {
|
Range)) {
|
||||||
|
@ -336,9 +336,9 @@ VPPartialReductionRecipe::computeCost(ElementCount VF,
|
|||||||
return TargetTransformInfo::PR_None;
|
return TargetTransformInfo::PR_None;
|
||||||
};
|
};
|
||||||
|
|
||||||
return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB,
|
return Ctx.TTI.getPartialReductionCost(
|
||||||
PhiType, VF, GetExtendKind(ExtAR),
|
getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR),
|
||||||
GetExtendKind(ExtBR), Opcode);
|
GetExtendKind(ExtBR), Opcode, Ctx.CostKind);
|
||||||
}
|
}
|
||||||
|
|
||||||
void VPPartialReductionRecipe::execute(VPTransformState &State) {
|
void VPPartialReductionRecipe::execute(VPTransformState &State) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user