From b96370131d1572feb9c51442ac8ba1ccb16d7071 Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Thu, 19 Jun 2025 15:29:56 -0700 Subject: [PATCH] [TTI] Plumb CostKind through getPartialReductionCost (#144953) Purely for the sake of being idiomatic with other TTI costing routines, no direct motivation beyond that. --- llvm/include/llvm/Analysis/TargetTransformInfo.h | 4 ++-- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h | 11 +++++------ llvm/lib/Analysis/TargetTransformInfo.cpp | 5 +++-- .../lib/Target/AArch64/AArch64TargetTransformInfo.cpp | 7 +++++-- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h | 11 +++++------ llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp | 9 ++++----- llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h | 11 +++++------ .../WebAssembly/WebAssemblyTargetTransformInfo.cpp | 7 +++++-- .../WebAssembly/WebAssemblyTargetTransformInfo.h | 4 ++-- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp | 2 +- llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 6 +++--- 11 files changed, 40 insertions(+), 37 deletions(-) diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 9dc4eca82492..ba47cef274be 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -1332,8 +1332,8 @@ public: LLVM_ABI InstructionCost getPartialReductionCost( unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, ElementCount VF, PartialReductionExtendKind OpAExtend, - PartialReductionExtendKind OpBExtend, - std::optional BinOp = std::nullopt) const; + PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const; /// \return The maximum interleave factor that any transform should try to /// perform for this target. This number depends on the level of parallelism diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index d93375218394..640766cf8cd1 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -652,12 +652,11 @@ public: virtual bool enableWritePrefetching() const { return false; } virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; } - virtual InstructionCost - getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB, - Type *AccumType, ElementCount VF, - TTI::PartialReductionExtendKind OpAExtend, - TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp = std::nullopt) const { + virtual InstructionCost getPartialReductionCost( + unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, + ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, + TTI::PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const { return InstructionCost::getInvalid(); } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index d9cb11de9c09..8cc7f8a9d2ab 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const { InstructionCost TargetTransformInfo::getPartialReductionCost( unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, ElementCount VF, PartialReductionExtendKind OpAExtend, - PartialReductionExtendKind OpBExtend, std::optional BinOp) const { + PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const { return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB, AccumType, VF, OpAExtend, OpBExtend, - BinOp); + BinOp, CostKind); } unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const { diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index ed051f295752..9d5c984fa4f1 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -5395,11 +5395,14 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index, InstructionCost AArch64TTIImpl::getPartialReductionCost( unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, - TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp) const { + TTI::PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const { InstructionCost Invalid = InstructionCost::getInvalid(); InstructionCost Cost(TTI::TCC_Basic); + if (CostKind != TTI::TCK_RecipThroughput) + return Invalid; + // Sub opcodes currently only occur in chained cases. // Independent partial reduction subtractions are still costed as an add if (Opcode != Instruction::Add && Opcode != Instruction::Sub) diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 0184e748b3d8..470af01be315 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -382,12 +382,11 @@ public: return BaseT::isLegalNTLoad(DataType, Alignment); } - InstructionCost - getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB, - Type *AccumType, ElementCount VF, - TTI::PartialReductionExtendKind OpAExtend, - TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp) const override; + InstructionCost getPartialReductionCost( + unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, + ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, + TTI::PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const override; bool enableOrderedReductions() const override { return true; } diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index 63c5f17a8487..1b80b0fcaf10 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const { InstructionCost RISCVTTIImpl::getPartialReductionCost( unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, - TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp) const { + TTI::PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const { // zve32x is broken for partial_reduce_umla, but let's make sure we // don't generate them. @@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost( Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4)); std::pair LT = getTypeLegalizationCost(Tp); // Note: Asuming all vqdot* variants are equal cost - // TODO: Thread CostKind through this API - return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, - TTI::TCK_RecipThroughput); + return LT.first * + getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind); } bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const { diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index 75d377abb0e7..83ac71ed9da6 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -100,12 +100,11 @@ public: TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth) const override; - InstructionCost - getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB, - Type *AccumType, ElementCount VF, - TTI::PartialReductionExtendKind OpAExtend, - TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp) const override; + InstructionCost getPartialReductionCost( + unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, + ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, + TTI::PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const override; bool shouldExpandReduction(const IntrinsicInst *II) const override; bool supportsScalableVectors() const override { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp index 978e08bb8955..4f159996e4c6 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp @@ -198,12 +198,15 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost( InstructionCost WebAssemblyTTIImpl::getPartialReductionCost( unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, - TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp) const { + TTI::PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const { InstructionCost Invalid = InstructionCost::getInvalid(); if (!VF.isFixed() || !ST->hasSIMD128()) return Invalid; + if (CostKind != TTI::TCK_RecipThroughput) + return Invalid; + InstructionCost Cost(TTI::TCC_Basic); // Possible options: diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h index 6b6d060076a8..d83b8d1f45db 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h @@ -86,8 +86,8 @@ public: InstructionCost getPartialReductionCost( unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType, ElementCount VF, TTI::PartialReductionExtendKind OpAExtend, - TTI::PartialReductionExtendKind OpBExtend, - std::optional BinOp = std::nullopt) const override; + TTI::PartialReductionExtendKind OpBExtend, std::optional BinOp, + TTI::TargetCostKind CostKind) const override; TTI::ReductionShuffle getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override; diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index e14f985efd96..9a2cd94eda58 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions( [&](ElementCount VF) { InstructionCost Cost = TTI->getPartialReductionCost( Update->getOpcode(), A->getType(), B->getType(), PHI->getType(), - VF, OpAExtend, OpBExtend, BinOp->getOpcode()); + VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind); return Cost.isValid(); }, Range)) { diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index f3b5c8cfa988..22861eb1c7df 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -336,9 +336,9 @@ VPPartialReductionRecipe::computeCost(ElementCount VF, return TargetTransformInfo::PR_None; }; - return Ctx.TTI.getPartialReductionCost(getOpcode(), InputTypeA, InputTypeB, - PhiType, VF, GetExtendKind(ExtAR), - GetExtendKind(ExtBR), Opcode); + return Ctx.TTI.getPartialReductionCost( + getOpcode(), InputTypeA, InputTypeB, PhiType, VF, GetExtendKind(ExtAR), + GetExtendKind(ExtBR), Opcode, Ctx.CostKind); } void VPPartialReductionRecipe::execute(VPTransformState &State) {