From 018ae02785da5a35bea8a6eb04f6398c052228f9 Mon Sep 17 00:00:00 2001 From: Matheus Izvekov Date: Thu, 9 Oct 2025 00:35:49 -0300 Subject: [PATCH] [clang] fix transform for constant template parameter type subst node (#162587) This fixes the transform to use the correct parameter type for an AssociatedDecl which has been fully specialized. Instead of using the type for the parameter of the specialized template, this uses the type of the argument it has been specialized with. This fixes a regression reported here: https://github.com/llvm/llvm-project/pull/161029#issuecomment-3375478990 Since this regression was never released, there are no release notes. --- clang/include/clang/AST/DeclTemplate.h | 7 ++- clang/lib/AST/DeclTemplate.cpp | 80 +++++++++++++----------- clang/lib/AST/ExprCXX.cpp | 4 +- clang/lib/AST/TemplateName.cpp | 10 ++- clang/lib/AST/Type.cpp | 14 ++--- clang/lib/Sema/TreeTransform.h | 17 +++-- clang/test/CodeGenCXX/template-cxx20.cpp | 24 +++++++ 7 files changed, 93 insertions(+), 63 deletions(-) create mode 100644 clang/test/CodeGenCXX/template-cxx20.cpp diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h index a3c67a60d032..a4a1bb9c13c7 100644 --- a/clang/include/clang/AST/DeclTemplate.h +++ b/clang/include/clang/AST/DeclTemplate.h @@ -3395,9 +3395,10 @@ inline UnsignedOrNone getExpandedPackSize(const NamedDecl *Param) { return std::nullopt; } -/// Internal helper used by Subst* nodes to retrieve the parameter list -/// for their AssociatedDecl. -TemplateParameterList *getReplacedTemplateParameterList(const Decl *D); +/// Internal helper used by Subst* nodes to retrieve a parameter from the +/// AssociatedDecl, and the template argument substituted into it, if any. +std::tuple +getReplacedTemplateParameter(Decl *D, unsigned Index); /// If we have a 'templated' declaration for a template, adjust 'D' to /// refer to the actual template. diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp index e5fba1bf9d6b..c0be986824cf 100644 --- a/clang/lib/AST/DeclTemplate.cpp +++ b/clang/lib/AST/DeclTemplate.cpp @@ -1653,57 +1653,65 @@ void TemplateParamObjectDecl::printAsInit(llvm::raw_ostream &OS, getValue().printPretty(OS, Policy, getType(), &getASTContext()); } -TemplateParameterList *clang::getReplacedTemplateParameterList(const Decl *D) { +std::tuple +clang::getReplacedTemplateParameter(Decl *D, unsigned Index) { switch (D->getKind()) { - case Decl::Kind::CXXRecord: - return cast(D) - ->getDescribedTemplate() - ->getTemplateParameters(); + case Decl::Kind::BuiltinTemplate: case Decl::Kind::ClassTemplate: - return cast(D)->getTemplateParameters(); + case Decl::Kind::Concept: + case Decl::Kind::FunctionTemplate: + case Decl::Kind::TemplateTemplateParm: + case Decl::Kind::TypeAliasTemplate: + case Decl::Kind::VarTemplate: + return {cast(D)->getTemplateParameters()->getParam(Index), + {}}; case Decl::Kind::ClassTemplateSpecialization: { const auto *CTSD = cast(D); auto P = CTSD->getSpecializedTemplateOrPartial(); + TemplateParameterList *TPL; if (const auto *CTPSD = dyn_cast(P)) - return CTPSD->getTemplateParameters(); - return cast(P)->getTemplateParameters(); + TPL = CTPSD->getTemplateParameters(); + else + TPL = cast(P)->getTemplateParameters(); + return {TPL->getParam(Index), CTSD->getTemplateArgs()[Index]}; + } + case Decl::Kind::VarTemplateSpecialization: { + const auto *VTSD = cast(D); + auto P = VTSD->getSpecializedTemplateOrPartial(); + TemplateParameterList *TPL; + if (const auto *VTPSD = dyn_cast(P)) + TPL = VTPSD->getTemplateParameters(); + else + TPL = cast(P)->getTemplateParameters(); + return {TPL->getParam(Index), VTSD->getTemplateArgs()[Index]}; } case Decl::Kind::ClassTemplatePartialSpecialization: - return cast(D) - ->getTemplateParameters(); - case Decl::Kind::TypeAliasTemplate: - return cast(D)->getTemplateParameters(); - case Decl::Kind::BuiltinTemplate: - return cast(D)->getTemplateParameters(); + return {cast(D) + ->getTemplateParameters() + ->getParam(Index), + {}}; + case Decl::Kind::VarTemplatePartialSpecialization: + return {cast(D) + ->getTemplateParameters() + ->getParam(Index), + {}}; + // This is used as the AssociatedDecl for placeholder type deduction. + case Decl::TemplateTypeParm: + return {cast(D), {}}; + // FIXME: Always use the template decl as the AssociatedDecl. + case Decl::Kind::CXXRecord: + return getReplacedTemplateParameter( + cast(D)->getDescribedClassTemplate(), Index); case Decl::Kind::CXXDeductionGuide: case Decl::Kind::CXXConversion: case Decl::Kind::CXXConstructor: case Decl::Kind::CXXDestructor: case Decl::Kind::CXXMethod: case Decl::Kind::Function: - return cast(D) - ->getTemplateSpecializationInfo() - ->getTemplate() - ->getTemplateParameters(); - case Decl::Kind::FunctionTemplate: - return cast(D)->getTemplateParameters(); - case Decl::Kind::VarTemplate: - return cast(D)->getTemplateParameters(); - case Decl::Kind::VarTemplateSpecialization: { - const auto *VTSD = cast(D); - auto P = VTSD->getSpecializedTemplateOrPartial(); - if (const auto *VTPSD = dyn_cast(P)) - return VTPSD->getTemplateParameters(); - return cast(P)->getTemplateParameters(); - } - case Decl::Kind::VarTemplatePartialSpecialization: - return cast(D) - ->getTemplateParameters(); - case Decl::Kind::TemplateTemplateParm: - return cast(D)->getTemplateParameters(); - case Decl::Kind::Concept: - return cast(D)->getTemplateParameters(); + return getReplacedTemplateParameter( + cast(D)->getTemplateSpecializationInfo()->getTemplate(), + Index); default: llvm_unreachable("Unhandled templated declaration kind"); } diff --git a/clang/lib/AST/ExprCXX.cpp b/clang/lib/AST/ExprCXX.cpp index 95de6a82a527..c7f0ff040194 100644 --- a/clang/lib/AST/ExprCXX.cpp +++ b/clang/lib/AST/ExprCXX.cpp @@ -1727,7 +1727,7 @@ SizeOfPackExpr *SizeOfPackExpr::CreateDeserialized(ASTContext &Context, NonTypeTemplateParmDecl *SubstNonTypeTemplateParmExpr::getParameter() const { return cast( - getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]); + std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index))); } PackIndexingExpr *PackIndexingExpr::Create( @@ -1793,7 +1793,7 @@ SubstNonTypeTemplateParmPackExpr::SubstNonTypeTemplateParmPackExpr( NonTypeTemplateParmDecl * SubstNonTypeTemplateParmPackExpr::getParameterPack() const { return cast( - getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]); + std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index))); } TemplateArgument SubstNonTypeTemplateParmPackExpr::getArgumentPack() const { diff --git a/clang/lib/AST/TemplateName.cpp b/clang/lib/AST/TemplateName.cpp index 2b8044e4188c..797a354c5d0f 100644 --- a/clang/lib/AST/TemplateName.cpp +++ b/clang/lib/AST/TemplateName.cpp @@ -64,16 +64,14 @@ SubstTemplateTemplateParmPackStorage::getArgumentPack() const { TemplateTemplateParmDecl * SubstTemplateTemplateParmPackStorage::getParameterPack() const { - return cast( - getReplacedTemplateParameterList(getAssociatedDecl()) - ->asArray()[Bits.Index]); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index))); } TemplateTemplateParmDecl * SubstTemplateTemplateParmStorage::getParameter() const { - return cast( - getReplacedTemplateParameterList(getAssociatedDecl()) - ->asArray()[Bits.Index]); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index))); } void SubstTemplateTemplateParmStorage::Profile(llvm::FoldingSetNodeID &ID) { diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index 9794314a98f8..ee7a68ee8ba7 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -4436,14 +4436,6 @@ IdentifierInfo *TemplateTypeParmType::getIdentifier() const { return isCanonicalUnqualified() ? nullptr : getDecl()->getIdentifier(); } -static const TemplateTypeParmDecl *getReplacedParameter(Decl *D, - unsigned Index) { - if (const auto *TTP = dyn_cast(D)) - return TTP; - return cast( - getReplacedTemplateParameterList(D)->getParam(Index)); -} - SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, unsigned Index, @@ -4466,7 +4458,8 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement, const TemplateTypeParmDecl * SubstTemplateTypeParmType::getReplacedParameter() const { - return ::getReplacedParameter(getAssociatedDecl(), getIndex()); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), getIndex()))); } void SubstTemplateTypeParmType::Profile(llvm::FoldingSetNodeID &ID, @@ -4532,7 +4525,8 @@ bool SubstTemplateTypeParmPackType::getFinal() const { const TemplateTypeParmDecl * SubstTemplateTypeParmPackType::getReplacedParameter() const { - return ::getReplacedParameter(getAssociatedDecl(), getIndex()); + return cast(std::get<0>( + getReplacedTemplateParameter(getAssociatedDecl(), getIndex()))); } IdentifierInfo *SubstTemplateTypeParmPackType::getIdentifier() const { diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 51b55b82f420..940324bbc5e4 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -16364,16 +16364,21 @@ ExprResult TreeTransform::TransformSubstNonTypeTemplateParmExpr( AssociatedDecl == E->getAssociatedDecl()) return E; + auto getParamAndType = [Index = E->getIndex()](Decl *AssociatedDecl) + -> std::tuple { + auto [PDecl, Arg] = getReplacedTemplateParameter(AssociatedDecl, Index); + auto *Param = cast(PDecl); + return {Param, Arg.isNull() ? Param->getType() + : Arg.getNonTypeTemplateArgumentType()}; + }; + // If the replacement expression did not change, and the parameter type // did not change, we can skip the semantic action because it would // produce the same result anyway. - auto *Param = cast( - getReplacedTemplateParameterList(AssociatedDecl) - ->asArray()[E->getIndex()]); - if (QualType ParamType = Param->getType(); - !SemaRef.Context.hasSameType(ParamType, E->getParameter()->getType()) || + if (auto [Param, ParamType] = getParamAndType(AssociatedDecl); + !SemaRef.Context.hasSameType( + ParamType, std::get<1>(getParamAndType(E->getAssociatedDecl()))) || Replacement.get() != OrigReplacement) { - // When transforming the replacement expression previously, all Sema // specific annotations, such as implicit casts, are discarded. Calling the // corresponding sema action is necessary to recover those. Otherwise, diff --git a/clang/test/CodeGenCXX/template-cxx20.cpp b/clang/test/CodeGenCXX/template-cxx20.cpp new file mode 100644 index 000000000000..aeb1cc915145 --- /dev/null +++ b/clang/test/CodeGenCXX/template-cxx20.cpp @@ -0,0 +1,24 @@ +// RUN: %clang_cc1 %s -O0 -disable-llvm-passes -triple=x86_64 -std=c++20 -emit-llvm -o - | FileCheck %s + +namespace GH161029_regression1 { + template auto f(int) { _Fp{}(0); } + template void g() { + (..., f<_Fp>(_Js)); + } + enum E { k }; + template struct ElementAt; + template struct ElementAt<0, First> { + static int value; + }; + template struct TagSet { + template using Tag = ElementAt; + }; + template struct S { + void U() { (void)TagSet::template Tag<0>::value; } + }; + S> s; + void h() { + g void { s.U(); }), 0>(); + } + // CHECK: call void @_ZN20GH161029_regression11SINS_6TagSetINS_1EELS2_0EEEE1UEv +}