[clang] fix transform for constant template parameter type subst node (#162587)

This fixes the transform to use the correct parameter type for an
AssociatedDecl which has been fully specialized.

Instead of using the type for the parameter of the specialized template,
this uses the type of the argument it has been specialized with.

This fixes a regression reported here:
https://github.com/llvm/llvm-project/pull/161029#issuecomment-3375478990

Since this regression was never released, there are no release notes.
This commit is contained in:
Matheus Izvekov 2025-10-09 00:35:49 -03:00 committed by GitHub
parent 822446d74a
commit 018ae02785
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 93 additions and 63 deletions

View File

@ -3395,9 +3395,10 @@ inline UnsignedOrNone getExpandedPackSize(const NamedDecl *Param) {
return std::nullopt;
}
/// Internal helper used by Subst* nodes to retrieve the parameter list
/// for their AssociatedDecl.
TemplateParameterList *getReplacedTemplateParameterList(const Decl *D);
/// Internal helper used by Subst* nodes to retrieve a parameter from the
/// AssociatedDecl, and the template argument substituted into it, if any.
std::tuple<NamedDecl *, TemplateArgument>
getReplacedTemplateParameter(Decl *D, unsigned Index);
/// If we have a 'templated' declaration for a template, adjust 'D' to
/// refer to the actual template.

View File

@ -1653,57 +1653,65 @@ void TemplateParamObjectDecl::printAsInit(llvm::raw_ostream &OS,
getValue().printPretty(OS, Policy, getType(), &getASTContext());
}
TemplateParameterList *clang::getReplacedTemplateParameterList(const Decl *D) {
std::tuple<NamedDecl *, TemplateArgument>
clang::getReplacedTemplateParameter(Decl *D, unsigned Index) {
switch (D->getKind()) {
case Decl::Kind::CXXRecord:
return cast<CXXRecordDecl>(D)
->getDescribedTemplate()
->getTemplateParameters();
case Decl::Kind::BuiltinTemplate:
case Decl::Kind::ClassTemplate:
return cast<ClassTemplateDecl>(D)->getTemplateParameters();
case Decl::Kind::Concept:
case Decl::Kind::FunctionTemplate:
case Decl::Kind::TemplateTemplateParm:
case Decl::Kind::TypeAliasTemplate:
case Decl::Kind::VarTemplate:
return {cast<TemplateDecl>(D)->getTemplateParameters()->getParam(Index),
{}};
case Decl::Kind::ClassTemplateSpecialization: {
const auto *CTSD = cast<ClassTemplateSpecializationDecl>(D);
auto P = CTSD->getSpecializedTemplateOrPartial();
TemplateParameterList *TPL;
if (const auto *CTPSD =
dyn_cast<ClassTemplatePartialSpecializationDecl *>(P))
return CTPSD->getTemplateParameters();
return cast<ClassTemplateDecl *>(P)->getTemplateParameters();
TPL = CTPSD->getTemplateParameters();
else
TPL = cast<ClassTemplateDecl *>(P)->getTemplateParameters();
return {TPL->getParam(Index), CTSD->getTemplateArgs()[Index]};
}
case Decl::Kind::VarTemplateSpecialization: {
const auto *VTSD = cast<VarTemplateSpecializationDecl>(D);
auto P = VTSD->getSpecializedTemplateOrPartial();
TemplateParameterList *TPL;
if (const auto *VTPSD = dyn_cast<VarTemplatePartialSpecializationDecl *>(P))
TPL = VTPSD->getTemplateParameters();
else
TPL = cast<VarTemplateDecl *>(P)->getTemplateParameters();
return {TPL->getParam(Index), VTSD->getTemplateArgs()[Index]};
}
case Decl::Kind::ClassTemplatePartialSpecialization:
return cast<ClassTemplatePartialSpecializationDecl>(D)
->getTemplateParameters();
case Decl::Kind::TypeAliasTemplate:
return cast<TypeAliasTemplateDecl>(D)->getTemplateParameters();
case Decl::Kind::BuiltinTemplate:
return cast<BuiltinTemplateDecl>(D)->getTemplateParameters();
return {cast<ClassTemplatePartialSpecializationDecl>(D)
->getTemplateParameters()
->getParam(Index),
{}};
case Decl::Kind::VarTemplatePartialSpecialization:
return {cast<VarTemplatePartialSpecializationDecl>(D)
->getTemplateParameters()
->getParam(Index),
{}};
// This is used as the AssociatedDecl for placeholder type deduction.
case Decl::TemplateTypeParm:
return {cast<NamedDecl>(D), {}};
// FIXME: Always use the template decl as the AssociatedDecl.
case Decl::Kind::CXXRecord:
return getReplacedTemplateParameter(
cast<CXXRecordDecl>(D)->getDescribedClassTemplate(), Index);
case Decl::Kind::CXXDeductionGuide:
case Decl::Kind::CXXConversion:
case Decl::Kind::CXXConstructor:
case Decl::Kind::CXXDestructor:
case Decl::Kind::CXXMethod:
case Decl::Kind::Function:
return cast<FunctionDecl>(D)
->getTemplateSpecializationInfo()
->getTemplate()
->getTemplateParameters();
case Decl::Kind::FunctionTemplate:
return cast<FunctionTemplateDecl>(D)->getTemplateParameters();
case Decl::Kind::VarTemplate:
return cast<VarTemplateDecl>(D)->getTemplateParameters();
case Decl::Kind::VarTemplateSpecialization: {
const auto *VTSD = cast<VarTemplateSpecializationDecl>(D);
auto P = VTSD->getSpecializedTemplateOrPartial();
if (const auto *VTPSD = dyn_cast<VarTemplatePartialSpecializationDecl *>(P))
return VTPSD->getTemplateParameters();
return cast<VarTemplateDecl *>(P)->getTemplateParameters();
}
case Decl::Kind::VarTemplatePartialSpecialization:
return cast<VarTemplatePartialSpecializationDecl>(D)
->getTemplateParameters();
case Decl::Kind::TemplateTemplateParm:
return cast<TemplateTemplateParmDecl>(D)->getTemplateParameters();
case Decl::Kind::Concept:
return cast<ConceptDecl>(D)->getTemplateParameters();
return getReplacedTemplateParameter(
cast<FunctionDecl>(D)->getTemplateSpecializationInfo()->getTemplate(),
Index);
default:
llvm_unreachable("Unhandled templated declaration kind");
}

View File

@ -1727,7 +1727,7 @@ SizeOfPackExpr *SizeOfPackExpr::CreateDeserialized(ASTContext &Context,
NonTypeTemplateParmDecl *SubstNonTypeTemplateParmExpr::getParameter() const {
return cast<NonTypeTemplateParmDecl>(
getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]);
std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index)));
}
PackIndexingExpr *PackIndexingExpr::Create(
@ -1793,7 +1793,7 @@ SubstNonTypeTemplateParmPackExpr::SubstNonTypeTemplateParmPackExpr(
NonTypeTemplateParmDecl *
SubstNonTypeTemplateParmPackExpr::getParameterPack() const {
return cast<NonTypeTemplateParmDecl>(
getReplacedTemplateParameterList(getAssociatedDecl())->asArray()[Index]);
std::get<0>(getReplacedTemplateParameter(getAssociatedDecl(), Index)));
}
TemplateArgument SubstNonTypeTemplateParmPackExpr::getArgumentPack() const {

View File

@ -64,16 +64,14 @@ SubstTemplateTemplateParmPackStorage::getArgumentPack() const {
TemplateTemplateParmDecl *
SubstTemplateTemplateParmPackStorage::getParameterPack() const {
return cast<TemplateTemplateParmDecl>(
getReplacedTemplateParameterList(getAssociatedDecl())
->asArray()[Bits.Index]);
return cast<TemplateTemplateParmDecl>(std::get<0>(
getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index)));
}
TemplateTemplateParmDecl *
SubstTemplateTemplateParmStorage::getParameter() const {
return cast<TemplateTemplateParmDecl>(
getReplacedTemplateParameterList(getAssociatedDecl())
->asArray()[Bits.Index]);
return cast<TemplateTemplateParmDecl>(std::get<0>(
getReplacedTemplateParameter(getAssociatedDecl(), Bits.Index)));
}
void SubstTemplateTemplateParmStorage::Profile(llvm::FoldingSetNodeID &ID) {

View File

@ -4436,14 +4436,6 @@ IdentifierInfo *TemplateTypeParmType::getIdentifier() const {
return isCanonicalUnqualified() ? nullptr : getDecl()->getIdentifier();
}
static const TemplateTypeParmDecl *getReplacedParameter(Decl *D,
unsigned Index) {
if (const auto *TTP = dyn_cast<TemplateTypeParmDecl>(D))
return TTP;
return cast<TemplateTypeParmDecl>(
getReplacedTemplateParameterList(D)->getParam(Index));
}
SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement,
Decl *AssociatedDecl,
unsigned Index,
@ -4466,7 +4458,8 @@ SubstTemplateTypeParmType::SubstTemplateTypeParmType(QualType Replacement,
const TemplateTypeParmDecl *
SubstTemplateTypeParmType::getReplacedParameter() const {
return ::getReplacedParameter(getAssociatedDecl(), getIndex());
return cast<TemplateTypeParmDecl>(std::get<0>(
getReplacedTemplateParameter(getAssociatedDecl(), getIndex())));
}
void SubstTemplateTypeParmType::Profile(llvm::FoldingSetNodeID &ID,
@ -4532,7 +4525,8 @@ bool SubstTemplateTypeParmPackType::getFinal() const {
const TemplateTypeParmDecl *
SubstTemplateTypeParmPackType::getReplacedParameter() const {
return ::getReplacedParameter(getAssociatedDecl(), getIndex());
return cast<TemplateTypeParmDecl>(std::get<0>(
getReplacedTemplateParameter(getAssociatedDecl(), getIndex())));
}
IdentifierInfo *SubstTemplateTypeParmPackType::getIdentifier() const {

View File

@ -16364,16 +16364,21 @@ ExprResult TreeTransform<Derived>::TransformSubstNonTypeTemplateParmExpr(
AssociatedDecl == E->getAssociatedDecl())
return E;
auto getParamAndType = [Index = E->getIndex()](Decl *AssociatedDecl)
-> std::tuple<NonTypeTemplateParmDecl *, QualType> {
auto [PDecl, Arg] = getReplacedTemplateParameter(AssociatedDecl, Index);
auto *Param = cast<NonTypeTemplateParmDecl>(PDecl);
return {Param, Arg.isNull() ? Param->getType()
: Arg.getNonTypeTemplateArgumentType()};
};
// If the replacement expression did not change, and the parameter type
// did not change, we can skip the semantic action because it would
// produce the same result anyway.
auto *Param = cast<NonTypeTemplateParmDecl>(
getReplacedTemplateParameterList(AssociatedDecl)
->asArray()[E->getIndex()]);
if (QualType ParamType = Param->getType();
!SemaRef.Context.hasSameType(ParamType, E->getParameter()->getType()) ||
if (auto [Param, ParamType] = getParamAndType(AssociatedDecl);
!SemaRef.Context.hasSameType(
ParamType, std::get<1>(getParamAndType(E->getAssociatedDecl()))) ||
Replacement.get() != OrigReplacement) {
// When transforming the replacement expression previously, all Sema
// specific annotations, such as implicit casts, are discarded. Calling the
// corresponding sema action is necessary to recover those. Otherwise,

View File

@ -0,0 +1,24 @@
// RUN: %clang_cc1 %s -O0 -disable-llvm-passes -triple=x86_64 -std=c++20 -emit-llvm -o - | FileCheck %s
namespace GH161029_regression1 {
template <class _Fp> auto f(int) { _Fp{}(0); }
template <class _Fp, int... _Js> void g() {
(..., f<_Fp>(_Js));
}
enum E { k };
template <int, E> struct ElementAt;
template <E First> struct ElementAt<0, First> {
static int value;
};
template <typename T, T Item> struct TagSet {
template <int Index> using Tag = ElementAt<Index, Item>;
};
template <typename TagSet> struct S {
void U() { (void)TagSet::template Tag<0>::value; }
};
S<TagSet<E, k>> s;
void h() {
g<decltype([](auto) -> void { s.U(); }), 0>();
}
// CHECK: call void @_ZN20GH161029_regression11SINS_6TagSetINS_1EELS2_0EEEE1UEv
}