[NFC][Clang][OpenMP] Add helper functions/utils for finding/comparing attach base-ptrs. (#155625)
These have been pulled out of the codegen PR #153683, to reduce the size of that PR.
This commit is contained in:
parent
d45a135918
commit
777eea0732
@ -5816,6 +5816,12 @@ public:
|
||||
ValueDecl *getAssociatedDeclaration() const {
|
||||
return AssociatedDeclaration;
|
||||
}
|
||||
|
||||
bool operator==(const MappableComponent &Other) const {
|
||||
return AssociatedExpressionNonContiguousPr ==
|
||||
Other.AssociatedExpressionNonContiguousPr &&
|
||||
AssociatedDeclaration == Other.AssociatedDeclaration;
|
||||
}
|
||||
};
|
||||
|
||||
// List of components of an expression. This first one is the whole
|
||||
@ -5829,6 +5835,95 @@ public:
|
||||
using MappableExprComponentLists = SmallVector<MappableExprComponentList, 8>;
|
||||
using MappableExprComponentListsRef = ArrayRef<MappableExprComponentList>;
|
||||
|
||||
// Hash function to allow usage as DenseMap keys.
|
||||
friend llvm::hash_code hash_value(const MappableComponent &MC) {
|
||||
return llvm::hash_combine(MC.getAssociatedExpression(),
|
||||
MC.getAssociatedDeclaration(),
|
||||
MC.isNonContiguous());
|
||||
}
|
||||
|
||||
public:
|
||||
/// Get the type of an element of a ComponentList Expr \p Exp.
|
||||
///
|
||||
/// For something like the following:
|
||||
/// ```c
|
||||
/// int *p, **p;
|
||||
/// ```
|
||||
/// The types for the following Exprs would be:
|
||||
/// Expr | Type
|
||||
/// ---------|-----------
|
||||
/// p | int *
|
||||
/// *p | int
|
||||
/// p[0] | int
|
||||
/// p[0:1] | int
|
||||
/// pp | int **
|
||||
/// pp[0] | int *
|
||||
/// pp[0:1] | int *
|
||||
/// Note: this assumes that if \p Exp is an array-section, it is contiguous.
|
||||
static QualType getComponentExprElementType(const Expr *Exp);
|
||||
|
||||
/// Find the attach pointer expression from a list of mappable expression
|
||||
/// components.
|
||||
///
|
||||
/// This function traverses the component list to find the first
|
||||
/// expression that has a pointer type, which represents the attach
|
||||
/// base pointer expr for the current component-list.
|
||||
///
|
||||
/// For example, given the following:
|
||||
///
|
||||
/// ```c
|
||||
/// struct S {
|
||||
/// int a;
|
||||
/// int b[10];
|
||||
/// int c[10][10];
|
||||
/// int *p;
|
||||
/// int **pp;
|
||||
/// }
|
||||
/// S s, *ps, **pps, *(pas[10]), ***ppps;
|
||||
/// int i;
|
||||
/// ```
|
||||
///
|
||||
/// The base-pointers for the following map operands would be:
|
||||
/// map list-item | attach base-pointer | attach base-pointer
|
||||
/// | for directives except | target_update (if
|
||||
/// | target_update | different)
|
||||
/// ----------------|-----------------------|---------------------
|
||||
/// s | N/A |
|
||||
/// s.a | N/A |
|
||||
/// s.p | N/A |
|
||||
/// ps | N/A |
|
||||
/// ps->p | ps |
|
||||
/// ps[1] | ps |
|
||||
/// *(ps + 1) | ps |
|
||||
/// (ps + 1)[1] | ps |
|
||||
/// ps[1:10] | ps |
|
||||
/// ps->b[10] | ps |
|
||||
/// ps->p[10] | ps->p |
|
||||
/// ps->c[1][2] | ps |
|
||||
/// ps->c[1:2][2] | (error diagnostic) | N/A, TODO: ps
|
||||
/// ps->c[1:1][2] | ps | N/A, TODO: ps
|
||||
/// pps[1][2] | pps[1] |
|
||||
/// pps[1:1][2] | pps[1:1] | N/A, TODO: pps[1:1]
|
||||
/// pps[1:i][2] | pps[1:i] | N/A, TODO: pps[1:i]
|
||||
/// pps[1:2][2] | (error diagnostic) | N/A
|
||||
/// pps[1]->p | pps[1] |
|
||||
/// pps[1]->p[10] | pps[1] |
|
||||
/// pas[1] | N/A |
|
||||
/// pas[1][2] | pas[1] |
|
||||
/// ppps[1][2] | ppps[1] |
|
||||
/// ppps[1][2][3] | ppps[1][2] |
|
||||
/// ppps[1][2:1][3] | ppps[1][2:1] | N/A, TODO: ppps[1][2:1]
|
||||
/// ppps[1][2:2][3] | (error diagnostic) | N/A
|
||||
/// Returns a pair of the attach pointer expression and its depth in the
|
||||
/// component list.
|
||||
/// TODO: This may need to be updated to handle ref_ptr/ptee cases for byref
|
||||
/// map operands.
|
||||
/// TODO: Handle cases for target-update, where the list-item is a
|
||||
/// non-contiguous array-section that still has a base-pointer.
|
||||
static std::pair<const Expr *, std::optional<size_t>>
|
||||
findAttachPtrExpr(MappableExprComponentListRef Components,
|
||||
OpenMPDirectiveKind CurDirKind);
|
||||
|
||||
protected:
|
||||
// Return the total number of elements in a list of component lists.
|
||||
static unsigned
|
||||
|
||||
@ -312,6 +312,14 @@ bool isOpenMPTargetExecutionDirective(OpenMPDirectiveKind DKind);
|
||||
/// otherwise - false.
|
||||
bool isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind);
|
||||
|
||||
/// Checks if the specified directive is a map-entering target directive.
|
||||
/// \param DKind Specified directive.
|
||||
/// \return true - the directive is a map-entering target directive like
|
||||
/// 'omp target', 'omp target data', 'omp target enter data',
|
||||
/// 'omp target parallel', etc. (excludes 'omp target exit data', 'omp target
|
||||
/// update') otherwise - false.
|
||||
bool isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind);
|
||||
|
||||
/// Checks if the specified composite/combined directive constitutes a teams
|
||||
/// directive in the outermost nest. For example
|
||||
/// 'omp teams distribute' or 'omp teams distribute parallel for'.
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include "clang/AST/Attr.h"
|
||||
#include "clang/AST/Decl.h"
|
||||
#include "clang/AST/DeclOpenMP.h"
|
||||
#include "clang/AST/ExprOpenMP.h"
|
||||
#include "clang/Basic/LLVM.h"
|
||||
#include "clang/Basic/OpenMPKinds.h"
|
||||
#include "clang/Basic/TargetInfo.h"
|
||||
@ -1159,6 +1160,77 @@ unsigned OMPClauseMappableExprCommon::getUniqueDeclarationsTotalNumber(
|
||||
return UniqueDecls.size();
|
||||
}
|
||||
|
||||
QualType
|
||||
OMPClauseMappableExprCommon::getComponentExprElementType(const Expr *Exp) {
|
||||
assert(!isa<OMPArrayShapingExpr>(Exp) &&
|
||||
"Cannot get element-type from array-shaping expr.");
|
||||
|
||||
// Unless we are handling array-section expressions, including
|
||||
// array-subscripts, derefs, we can rely on getType.
|
||||
if (!isa<ArraySectionExpr>(Exp))
|
||||
return Exp->getType().getNonReferenceType().getCanonicalType();
|
||||
|
||||
// For array-sections, we need to find the type of one element of
|
||||
// the section.
|
||||
const auto *OASE = cast<ArraySectionExpr>(Exp);
|
||||
|
||||
QualType BaseType = ArraySectionExpr::getBaseOriginalType(OASE->getBase());
|
||||
|
||||
QualType ElemTy;
|
||||
if (const auto *ATy = BaseType->getAsArrayTypeUnsafe())
|
||||
ElemTy = ATy->getElementType();
|
||||
else
|
||||
ElemTy = BaseType->getPointeeType();
|
||||
|
||||
ElemTy = ElemTy.getNonReferenceType().getCanonicalType();
|
||||
return ElemTy;
|
||||
}
|
||||
|
||||
std::pair<const Expr *, std::optional<size_t>>
|
||||
OMPClauseMappableExprCommon::findAttachPtrExpr(
|
||||
MappableExprComponentListRef Components, OpenMPDirectiveKind CurDirKind) {
|
||||
|
||||
// If we only have a single component, we have a map like "map(p)", which
|
||||
// cannot have a base-pointer.
|
||||
if (Components.size() < 2)
|
||||
return {nullptr, std::nullopt};
|
||||
|
||||
// Only check for non-contiguous sections on target_update, since we can
|
||||
// assume array-sections are contiguous on maps on other constructs, even if
|
||||
// we are not sure of it at compile-time, like for a[1:x][2].
|
||||
if (Components.back().isNonContiguous() && CurDirKind == OMPD_target_update)
|
||||
return {nullptr, std::nullopt};
|
||||
|
||||
// To find the attach base-pointer, we start with the second component,
|
||||
// stripping away one component at a time, until we reach a pointer Expr
|
||||
// (that is not a binary operator). The first such pointer should be the
|
||||
// attach base-pointer for the component list.
|
||||
for (auto [I, Component] : llvm::enumerate(Components)) {
|
||||
// Skip past the first component.
|
||||
if (I == 0)
|
||||
continue;
|
||||
|
||||
const Expr *CurExpr = Component.getAssociatedExpression();
|
||||
if (!CurExpr)
|
||||
break;
|
||||
|
||||
// If CurExpr is something like `p + 10`, we need to ignore it, since
|
||||
// we are looking for `p`.
|
||||
if (isa<BinaryOperator>(CurExpr))
|
||||
continue;
|
||||
|
||||
// Keep going until we reach an Expr of pointer type.
|
||||
QualType CurType = getComponentExprElementType(CurExpr);
|
||||
if (!CurType->isPointerType())
|
||||
continue;
|
||||
|
||||
// We have found a pointer Expr. This must be the attach pointer.
|
||||
return {CurExpr, Components.size() - I};
|
||||
}
|
||||
|
||||
return {nullptr, std::nullopt};
|
||||
}
|
||||
|
||||
OMPMapClause *OMPMapClause::Create(
|
||||
const ASTContext &C, const OMPVarListLocTy &Locs, ArrayRef<Expr *> Vars,
|
||||
ArrayRef<ValueDecl *> Declarations,
|
||||
|
||||
@ -677,6 +677,11 @@ bool clang::isOpenMPTargetDataManagementDirective(OpenMPDirectiveKind DKind) {
|
||||
DKind == OMPD_target_exit_data || DKind == OMPD_target_update;
|
||||
}
|
||||
|
||||
bool clang::isOpenMPTargetMapEnteringDirective(OpenMPDirectiveKind DKind) {
|
||||
return DKind == OMPD_target_data || DKind == OMPD_target_enter_data ||
|
||||
isOpenMPTargetExecutionDirective(DKind);
|
||||
}
|
||||
|
||||
bool clang::isOpenMPNestingTeamsDirective(OpenMPDirectiveKind DKind) {
|
||||
if (DKind == OMPD_teams)
|
||||
return true;
|
||||
|
||||
@ -6799,6 +6799,240 @@ LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
|
||||
// code for that information.
|
||||
class MappableExprsHandler {
|
||||
public:
|
||||
/// Custom comparator for attach-pointer expressions that compares them by
|
||||
/// complexity (i.e. their component-depth) first, then by the order in which
|
||||
/// they were computed by collectAttachPtrExprInfo(), if they are semantically
|
||||
/// different.
|
||||
struct AttachPtrExprComparator {
|
||||
const MappableExprsHandler *Handler;
|
||||
// Cache of previous equality comparison results.
|
||||
mutable llvm::DenseMap<std::pair<const Expr *, const Expr *>, bool>
|
||||
CachedEqualityComparisons;
|
||||
|
||||
AttachPtrExprComparator(const MappableExprsHandler *H) : Handler(H) {}
|
||||
|
||||
// Return true iff LHS is "less than" RHS.
|
||||
bool operator()(const Expr *LHS, const Expr *RHS) const {
|
||||
if (LHS == RHS)
|
||||
return false;
|
||||
|
||||
// First, compare by complexity (depth)
|
||||
auto ItLHS = Handler->AttachPtrComponentDepthMap.find(LHS);
|
||||
auto ItRHS = Handler->AttachPtrComponentDepthMap.find(RHS);
|
||||
|
||||
std::optional<size_t> DepthLHS =
|
||||
(ItLHS != Handler->AttachPtrComponentDepthMap.end()) ? ItLHS->second
|
||||
: std::nullopt;
|
||||
std::optional<size_t> DepthRHS =
|
||||
(ItRHS != Handler->AttachPtrComponentDepthMap.end()) ? ItRHS->second
|
||||
: std::nullopt;
|
||||
|
||||
// std::nullopt (no attach pointer) has lowest complexity
|
||||
if (!DepthLHS.has_value() && !DepthRHS.has_value()) {
|
||||
// Both have same complexity, now check semantic equality
|
||||
if (areEqual(LHS, RHS))
|
||||
return false;
|
||||
// Different semantically, compare by computation order
|
||||
return wasComputedBefore(LHS, RHS);
|
||||
}
|
||||
if (!DepthLHS.has_value())
|
||||
return true; // LHS has lower complexity
|
||||
if (!DepthRHS.has_value())
|
||||
return false; // RHS has lower complexity
|
||||
|
||||
// Both have values, compare by depth (lower depth = lower complexity)
|
||||
if (DepthLHS.value() != DepthRHS.value())
|
||||
return DepthLHS.value() < DepthRHS.value();
|
||||
|
||||
// Same complexity, now check semantic equality
|
||||
if (areEqual(LHS, RHS))
|
||||
return false;
|
||||
// Different semantically, compare by computation order
|
||||
return wasComputedBefore(LHS, RHS);
|
||||
}
|
||||
|
||||
public:
|
||||
/// Return true if \p LHS and \p RHS are semantically equal. Uses pre-cached
|
||||
/// results, if available, otherwise does a recursive semantic comparison.
|
||||
bool areEqual(const Expr *LHS, const Expr *RHS) const {
|
||||
// Check cache first for faster lookup
|
||||
auto CachedResultIt = CachedEqualityComparisons.find({LHS, RHS});
|
||||
if (CachedResultIt != CachedEqualityComparisons.end())
|
||||
return CachedResultIt->second;
|
||||
|
||||
bool ComparisonResult = areSemanticallyEqual(LHS, RHS);
|
||||
|
||||
// Cache the result for future lookups (both orders since semantic
|
||||
// equality is commutative)
|
||||
CachedEqualityComparisons[{LHS, RHS}] = ComparisonResult;
|
||||
CachedEqualityComparisons[{RHS, LHS}] = ComparisonResult;
|
||||
return ComparisonResult;
|
||||
}
|
||||
|
||||
/// Compare the two attach-ptr expressions by their computation order.
|
||||
/// Returns true iff LHS was computed before RHS by
|
||||
/// collectAttachPtrExprInfo().
|
||||
bool wasComputedBefore(const Expr *LHS, const Expr *RHS) const {
|
||||
const size_t &OrderLHS = Handler->AttachPtrComputationOrderMap.at(LHS);
|
||||
const size_t &OrderRHS = Handler->AttachPtrComputationOrderMap.at(RHS);
|
||||
|
||||
return OrderLHS < OrderRHS;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Helper function to compare attach-pointer expressions semantically.
|
||||
/// This function handles various expression types that can be part of an
|
||||
/// attach-pointer.
|
||||
/// TODO: Not urgent, but we should ideally return true when comparing
|
||||
/// `p[10]`, `*(p + 10)`, `*(p + 5 + 5)`, `p[10:1]` etc.
|
||||
bool areSemanticallyEqual(const Expr *LHS, const Expr *RHS) const {
|
||||
if (LHS == RHS)
|
||||
return true;
|
||||
|
||||
// If only one is null, they aren't equal
|
||||
if (!LHS || !RHS)
|
||||
return false;
|
||||
|
||||
ASTContext &Ctx = Handler->CGF.getContext();
|
||||
// Strip away parentheses and no-op casts to get to the core expression
|
||||
LHS = LHS->IgnoreParenNoopCasts(Ctx);
|
||||
RHS = RHS->IgnoreParenNoopCasts(Ctx);
|
||||
|
||||
// Direct pointer comparison of the underlying expressions
|
||||
if (LHS == RHS)
|
||||
return true;
|
||||
|
||||
// Check if the expression classes match
|
||||
if (LHS->getStmtClass() != RHS->getStmtClass())
|
||||
return false;
|
||||
|
||||
// Handle DeclRefExpr (variable references)
|
||||
if (const auto *LD = dyn_cast<DeclRefExpr>(LHS)) {
|
||||
const auto *RD = dyn_cast<DeclRefExpr>(RHS);
|
||||
if (!RD)
|
||||
return false;
|
||||
return LD->getDecl()->getCanonicalDecl() ==
|
||||
RD->getDecl()->getCanonicalDecl();
|
||||
}
|
||||
|
||||
// Handle ArraySubscriptExpr (array indexing like a[i])
|
||||
if (const auto *LA = dyn_cast<ArraySubscriptExpr>(LHS)) {
|
||||
const auto *RA = dyn_cast<ArraySubscriptExpr>(RHS);
|
||||
if (!RA)
|
||||
return false;
|
||||
return areSemanticallyEqual(LA->getBase(), RA->getBase()) &&
|
||||
areSemanticallyEqual(LA->getIdx(), RA->getIdx());
|
||||
}
|
||||
|
||||
// Handle MemberExpr (member access like s.m or p->m)
|
||||
if (const auto *LM = dyn_cast<MemberExpr>(LHS)) {
|
||||
const auto *RM = dyn_cast<MemberExpr>(RHS);
|
||||
if (!RM)
|
||||
return false;
|
||||
if (LM->getMemberDecl()->getCanonicalDecl() !=
|
||||
RM->getMemberDecl()->getCanonicalDecl())
|
||||
return false;
|
||||
return areSemanticallyEqual(LM->getBase(), RM->getBase());
|
||||
}
|
||||
|
||||
// Handle UnaryOperator (unary operations like *p, &x, etc.)
|
||||
if (const auto *LU = dyn_cast<UnaryOperator>(LHS)) {
|
||||
const auto *RU = dyn_cast<UnaryOperator>(RHS);
|
||||
if (!RU)
|
||||
return false;
|
||||
if (LU->getOpcode() != RU->getOpcode())
|
||||
return false;
|
||||
return areSemanticallyEqual(LU->getSubExpr(), RU->getSubExpr());
|
||||
}
|
||||
|
||||
// Handle BinaryOperator (binary operations like p + offset)
|
||||
if (const auto *LB = dyn_cast<BinaryOperator>(LHS)) {
|
||||
const auto *RB = dyn_cast<BinaryOperator>(RHS);
|
||||
if (!RB)
|
||||
return false;
|
||||
if (LB->getOpcode() != RB->getOpcode())
|
||||
return false;
|
||||
return areSemanticallyEqual(LB->getLHS(), RB->getLHS()) &&
|
||||
areSemanticallyEqual(LB->getRHS(), RB->getRHS());
|
||||
}
|
||||
|
||||
// Handle ArraySectionExpr (array sections like a[0:1])
|
||||
// Attach pointers should not contain array-sections, but currently we
|
||||
// don't emit an error.
|
||||
if (const auto *LAS = dyn_cast<ArraySectionExpr>(LHS)) {
|
||||
const auto *RAS = dyn_cast<ArraySectionExpr>(RHS);
|
||||
if (!RAS)
|
||||
return false;
|
||||
return areSemanticallyEqual(LAS->getBase(), RAS->getBase()) &&
|
||||
areSemanticallyEqual(LAS->getLowerBound(),
|
||||
RAS->getLowerBound()) &&
|
||||
areSemanticallyEqual(LAS->getLength(), RAS->getLength());
|
||||
}
|
||||
|
||||
// Handle CastExpr (explicit casts)
|
||||
if (const auto *LC = dyn_cast<CastExpr>(LHS)) {
|
||||
const auto *RC = dyn_cast<CastExpr>(RHS);
|
||||
if (!RC)
|
||||
return false;
|
||||
if (LC->getCastKind() != RC->getCastKind())
|
||||
return false;
|
||||
return areSemanticallyEqual(LC->getSubExpr(), RC->getSubExpr());
|
||||
}
|
||||
|
||||
// Handle CXXThisExpr (this pointer)
|
||||
if (isa<CXXThisExpr>(LHS) && isa<CXXThisExpr>(RHS))
|
||||
return true;
|
||||
|
||||
// Handle IntegerLiteral (integer constants)
|
||||
if (const auto *LI = dyn_cast<IntegerLiteral>(LHS)) {
|
||||
const auto *RI = dyn_cast<IntegerLiteral>(RHS);
|
||||
if (!RI)
|
||||
return false;
|
||||
return LI->getValue() == RI->getValue();
|
||||
}
|
||||
|
||||
// Handle CharacterLiteral (character constants)
|
||||
if (const auto *LC = dyn_cast<CharacterLiteral>(LHS)) {
|
||||
const auto *RC = dyn_cast<CharacterLiteral>(RHS);
|
||||
if (!RC)
|
||||
return false;
|
||||
return LC->getValue() == RC->getValue();
|
||||
}
|
||||
|
||||
// Handle FloatingLiteral (floating point constants)
|
||||
if (const auto *LF = dyn_cast<FloatingLiteral>(LHS)) {
|
||||
const auto *RF = dyn_cast<FloatingLiteral>(RHS);
|
||||
if (!RF)
|
||||
return false;
|
||||
// Use bitwise comparison for floating point literals
|
||||
return LF->getValue().bitwiseIsEqual(RF->getValue());
|
||||
}
|
||||
|
||||
// Handle StringLiteral (string constants)
|
||||
if (const auto *LS = dyn_cast<StringLiteral>(LHS)) {
|
||||
const auto *RS = dyn_cast<StringLiteral>(RHS);
|
||||
if (!RS)
|
||||
return false;
|
||||
return LS->getString() == RS->getString();
|
||||
}
|
||||
|
||||
// Handle CXXNullPtrLiteralExpr (nullptr)
|
||||
if (isa<CXXNullPtrLiteralExpr>(LHS) && isa<CXXNullPtrLiteralExpr>(RHS))
|
||||
return true;
|
||||
|
||||
// Handle CXXBoolLiteralExpr (true/false)
|
||||
if (const auto *LB = dyn_cast<CXXBoolLiteralExpr>(LHS)) {
|
||||
const auto *RB = dyn_cast<CXXBoolLiteralExpr>(RHS);
|
||||
if (!RB)
|
||||
return false;
|
||||
return LB->getValue() == RB->getValue();
|
||||
}
|
||||
|
||||
// Fallback for other forms - use the existing comparison method
|
||||
return Expr::isSameComparisonOperand(LHS, RHS);
|
||||
}
|
||||
};
|
||||
|
||||
/// Get the offset of the OMP_MAP_MEMBER_OF field.
|
||||
static unsigned getFlagMemberOffset() {
|
||||
unsigned Offset = 0;
|
||||
@ -6876,6 +7110,45 @@ public:
|
||||
bool HasCompleteRecord = false;
|
||||
};
|
||||
|
||||
/// A struct to store the attach pointer and pointee information, to be used
|
||||
/// when emitting an attach entry.
|
||||
struct AttachInfoTy {
|
||||
Address AttachPtrAddr = Address::invalid();
|
||||
Address AttachPteeAddr = Address::invalid();
|
||||
const ValueDecl *AttachPtrDecl = nullptr;
|
||||
const Expr *AttachMapExpr = nullptr;
|
||||
|
||||
bool isValid() const {
|
||||
return AttachPtrAddr.isValid() && AttachPteeAddr.isValid();
|
||||
}
|
||||
};
|
||||
|
||||
/// Check if there's any component list where the attach pointer expression
|
||||
/// matches the given captured variable.
|
||||
bool hasAttachEntryForCapturedVar(const ValueDecl *VD) const {
|
||||
for (const auto &AttachEntry : AttachPtrExprMap) {
|
||||
if (AttachEntry.second) {
|
||||
// Check if the attach pointer expression is a DeclRefExpr that
|
||||
// references the captured variable
|
||||
if (const auto *DRE = dyn_cast<DeclRefExpr>(AttachEntry.second))
|
||||
if (DRE->getDecl() == VD)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Get the previously-cached attach pointer for a component list, if-any.
|
||||
const Expr *getAttachPtrExpr(
|
||||
OMPClauseMappableExprCommon::MappableExprComponentListRef Components)
|
||||
const {
|
||||
auto It = AttachPtrExprMap.find(Components);
|
||||
if (It != AttachPtrExprMap.end())
|
||||
return It->second;
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Kind that defines how a device pointer has to be returned.
|
||||
struct MapInfo {
|
||||
@ -6948,6 +7221,27 @@ private:
|
||||
/// Map between lambda declarations and their map type.
|
||||
llvm::DenseMap<const ValueDecl *, const OMPMapClause *> LambdasMap;
|
||||
|
||||
/// Map from component lists to their attach pointer expressions.
|
||||
llvm::DenseMap<OMPClauseMappableExprCommon::MappableExprComponentListRef,
|
||||
const Expr *>
|
||||
AttachPtrExprMap;
|
||||
|
||||
/// Map from attach pointer expressions to their component depth.
|
||||
/// nullptr key has std::nullopt depth. This can be used to order attach-ptr
|
||||
/// expressions with increasing/decreasing depth.
|
||||
/// The component-depth of `nullptr` (i.e. no attach-ptr) is `std::nullopt`.
|
||||
/// TODO: Not urgent, but we should ideally use the number of pointer
|
||||
/// dereferences in an expr as an indicator of its complexity, instead of the
|
||||
/// component-depth. That would be needed for us to treat `p[1]`, `*(p + 10)`,
|
||||
/// `*(p + 5 + 5)` together.
|
||||
llvm::DenseMap<const Expr *, std::optional<size_t>>
|
||||
AttachPtrComponentDepthMap = {{nullptr, std::nullopt}};
|
||||
|
||||
/// Map from attach pointer expressions to the order they were computed in, in
|
||||
/// collectAttachPtrExprInfo().
|
||||
llvm::DenseMap<const Expr *, size_t> AttachPtrComputationOrderMap = {
|
||||
{nullptr, 0}};
|
||||
|
||||
llvm::Value *getExprTypeSize(const Expr *E) const {
|
||||
QualType ExprTy = E->getType().getCanonicalType();
|
||||
|
||||
@ -8167,6 +8461,104 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the address corresponding to \p PointerExpr.
|
||||
static Address getAttachPtrAddr(const Expr *PointerExpr,
|
||||
CodeGenFunction &CGF) {
|
||||
assert(PointerExpr && "Cannot get addr from null attach-ptr expr");
|
||||
Address AttachPtrAddr = Address::invalid();
|
||||
|
||||
if (auto *DRE = dyn_cast<DeclRefExpr>(PointerExpr)) {
|
||||
// If the pointer is a variable, we can use its address directly.
|
||||
AttachPtrAddr = CGF.EmitLValue(DRE).getAddress();
|
||||
} else if (auto *OASE = dyn_cast<ArraySectionExpr>(PointerExpr)) {
|
||||
AttachPtrAddr =
|
||||
CGF.EmitArraySectionExpr(OASE, /*IsLowerBound=*/true).getAddress();
|
||||
} else if (auto *ASE = dyn_cast<ArraySubscriptExpr>(PointerExpr)) {
|
||||
AttachPtrAddr = CGF.EmitLValue(ASE).getAddress();
|
||||
} else if (auto *ME = dyn_cast<MemberExpr>(PointerExpr)) {
|
||||
AttachPtrAddr = CGF.EmitMemberExpr(ME).getAddress();
|
||||
} else if (auto *UO = dyn_cast<UnaryOperator>(PointerExpr)) {
|
||||
if (UO->getOpcode() == UO_Deref)
|
||||
AttachPtrAddr = CGF.EmitLValue(UO).getAddress();
|
||||
}
|
||||
assert(AttachPtrAddr.isValid() &&
|
||||
"Failed to get address for attach pointer expression");
|
||||
return AttachPtrAddr;
|
||||
}
|
||||
|
||||
/// Get the address of the attach pointer, and a load from it, to get the
|
||||
/// pointee base address.
|
||||
/// \return A pair containing AttachPtrAddr and AttachPteeBaseAddr. The pair
|
||||
/// contains invalid addresses if \p AttachPtrExpr is null.
|
||||
static std::pair<Address, Address>
|
||||
getAttachPtrAddrAndPteeBaseAddr(const Expr *AttachPtrExpr,
|
||||
CodeGenFunction &CGF) {
|
||||
|
||||
if (!AttachPtrExpr)
|
||||
return {Address::invalid(), Address::invalid()};
|
||||
|
||||
Address AttachPtrAddr = getAttachPtrAddr(AttachPtrExpr, CGF);
|
||||
assert(AttachPtrAddr.isValid() && "Invalid attach pointer addr");
|
||||
|
||||
QualType AttachPtrType =
|
||||
OMPClauseMappableExprCommon::getComponentExprElementType(AttachPtrExpr)
|
||||
.getCanonicalType();
|
||||
|
||||
Address AttachPteeBaseAddr = CGF.EmitLoadOfPointer(
|
||||
AttachPtrAddr, AttachPtrType->castAs<PointerType>());
|
||||
assert(AttachPteeBaseAddr.isValid() && "Invalid attach pointee base addr");
|
||||
|
||||
return {AttachPtrAddr, AttachPteeBaseAddr};
|
||||
}
|
||||
|
||||
/// Returns whether an attach entry should be emitted for a map on
|
||||
/// \p MapBaseDecl on the directive \p CurDir.
|
||||
static bool
|
||||
shouldEmitAttachEntry(const Expr *PointerExpr, const ValueDecl *MapBaseDecl,
|
||||
CodeGenFunction &CGF,
|
||||
llvm::PointerUnion<const OMPExecutableDirective *,
|
||||
const OMPDeclareMapperDecl *>
|
||||
CurDir) {
|
||||
if (!PointerExpr)
|
||||
return false;
|
||||
|
||||
// Pointer attachment is needed at map-entering time or for declare
|
||||
// mappers.
|
||||
if (!isa<const OMPDeclareMapperDecl *>(CurDir) &&
|
||||
!isOpenMPTargetMapEnteringDirective(
|
||||
cast<const OMPExecutableDirective *>(CurDir)->getDirectiveKind()))
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Computes the attach-ptr expr for \p Components, and updates various maps
|
||||
/// with the information.
|
||||
/// It internally calls OMPClauseMappableExprCommon::findAttachPtrExpr()
|
||||
/// with the OpenMPDirectiveKind extracted from \p CurDir.
|
||||
/// It updates AttachPtrComputationOrderMap, AttachPtrComponentDepthMap, and
|
||||
/// AttachPtrExprMap.
|
||||
void collectAttachPtrExprInfo(
|
||||
OMPClauseMappableExprCommon::MappableExprComponentListRef Components,
|
||||
llvm::PointerUnion<const OMPExecutableDirective *,
|
||||
const OMPDeclareMapperDecl *>
|
||||
CurDir) {
|
||||
|
||||
OpenMPDirectiveKind CurDirectiveID =
|
||||
isa<const OMPDeclareMapperDecl *>(CurDir)
|
||||
? OMPD_declare_mapper
|
||||
: cast<const OMPExecutableDirective *>(CurDir)->getDirectiveKind();
|
||||
|
||||
const auto &[AttachPtrExpr, Depth] =
|
||||
OMPClauseMappableExprCommon::findAttachPtrExpr(Components,
|
||||
CurDirectiveID);
|
||||
|
||||
AttachPtrComputationOrderMap.try_emplace(
|
||||
AttachPtrExpr, AttachPtrComputationOrderMap.size());
|
||||
AttachPtrComponentDepthMap.try_emplace(AttachPtrExpr, Depth);
|
||||
AttachPtrExprMap.try_emplace(Components, AttachPtrExpr);
|
||||
}
|
||||
|
||||
/// Generate all the base pointers, section pointers, sizes, map types, and
|
||||
/// mappers for the extracted mappable expressions (all included in \a
|
||||
/// CombinedInfo). Also, for each item that relates with a device pointer, a
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user