[OpenMP][Clang] Add codegen support for dyn_groupprivate clause (#152830)
This adds the codegen support for the dyn_groupprivate clause.
This commit is contained in:
parent
b4a61517a6
commit
540250ca7a
@ -10013,19 +10013,44 @@ static llvm::Value *emitDeviceID(
|
||||
return DeviceID;
|
||||
}
|
||||
|
||||
static llvm::Value *emitDynCGGroupMem(const OMPExecutableDirective &D,
|
||||
CodeGenFunction &CGF) {
|
||||
llvm::Value *DynCGroupMem = CGF.Builder.getInt32(0);
|
||||
static std::pair<llvm::Value *, OMPDynGroupprivateFallbackType>
|
||||
emitDynCGroupMem(const OMPExecutableDirective &D, CodeGenFunction &CGF) {
|
||||
llvm::Value *DynGP = CGF.Builder.getInt32(0);
|
||||
auto DynGPFallback = OMPDynGroupprivateFallbackType::Abort;
|
||||
|
||||
if (auto *DynMemClause = D.getSingleClause<OMPXDynCGroupMemClause>()) {
|
||||
CodeGenFunction::RunCleanupsScope DynCGroupMemScope(CGF);
|
||||
llvm::Value *DynCGroupMemVal = CGF.EmitScalarExpr(
|
||||
DynMemClause->getSize(), /*IgnoreResultAssign=*/true);
|
||||
DynCGroupMem = CGF.Builder.CreateIntCast(DynCGroupMemVal, CGF.Int32Ty,
|
||||
/*isSigned=*/false);
|
||||
if (auto *DynGPClause = D.getSingleClause<OMPDynGroupprivateClause>()) {
|
||||
CodeGenFunction::RunCleanupsScope DynGPScope(CGF);
|
||||
llvm::Value *DynGPVal =
|
||||
CGF.EmitScalarExpr(DynGPClause->getSize(), /*IgnoreResultAssign=*/true);
|
||||
DynGP = CGF.Builder.CreateIntCast(DynGPVal, CGF.Int32Ty,
|
||||
/*isSigned=*/false);
|
||||
auto FallbackModifier = DynGPClause->getDynGroupprivateFallbackModifier();
|
||||
switch (FallbackModifier) {
|
||||
case OMPC_DYN_GROUPPRIVATE_FALLBACK_abort:
|
||||
DynGPFallback = OMPDynGroupprivateFallbackType::Abort;
|
||||
break;
|
||||
case OMPC_DYN_GROUPPRIVATE_FALLBACK_null:
|
||||
DynGPFallback = OMPDynGroupprivateFallbackType::Null;
|
||||
break;
|
||||
case OMPC_DYN_GROUPPRIVATE_FALLBACK_default_mem:
|
||||
case OMPC_DYN_GROUPPRIVATE_FALLBACK_unknown:
|
||||
// This is the default for dyn_groupprivate.
|
||||
DynGPFallback = OMPDynGroupprivateFallbackType::DefaultMem;
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("Unknown fallback modifier for OpenMP dyn_groupprivate");
|
||||
}
|
||||
} else if (auto *OMPXDynCGClause =
|
||||
D.getSingleClause<OMPXDynCGroupMemClause>()) {
|
||||
CodeGenFunction::RunCleanupsScope DynCGMemScope(CGF);
|
||||
llvm::Value *DynCGMemVal = CGF.EmitScalarExpr(OMPXDynCGClause->getSize(),
|
||||
/*IgnoreResultAssign=*/true);
|
||||
DynGP = CGF.Builder.CreateIntCast(DynCGMemVal, CGF.Int32Ty,
|
||||
/*isSigned=*/false);
|
||||
}
|
||||
return DynCGroupMem;
|
||||
return {DynGP, DynGPFallback};
|
||||
}
|
||||
|
||||
static void genMapInfoForCaptures(
|
||||
MappableExprsHandler &MEHandler, CodeGenFunction &CGF,
|
||||
const CapturedStmt &CS, llvm::SmallVectorImpl<llvm::Value *> &CapturedVars,
|
||||
@ -10234,7 +10259,7 @@ static void emitTargetCallKernelLaunch(
|
||||
llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
|
||||
llvm::Value *NumIterations =
|
||||
OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);
|
||||
llvm::Value *DynCGGroupMem = emitDynCGGroupMem(D, CGF);
|
||||
auto [DynCGroupMem, DynCGroupMemFallback] = emitDynCGroupMem(D, CGF);
|
||||
llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
|
||||
CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
|
||||
|
||||
@ -10244,7 +10269,7 @@ static void emitTargetCallKernelLaunch(
|
||||
|
||||
llvm::OpenMPIRBuilder::TargetKernelArgs Args(
|
||||
NumTargetItems, RTArgs, NumIterations, NumTeams, NumThreads,
|
||||
DynCGGroupMem, HasNoWait);
|
||||
DynCGroupMem, HasNoWait, DynCGroupMemFallback);
|
||||
|
||||
llvm::OpenMPIRBuilder::InsertPointTy AfterIP =
|
||||
cantFail(OMPRuntime->getOMPBuilder().emitKernelLaunch(
|
||||
|
||||
2633
clang/test/OpenMP/target_dyn_groupprivate_codegen.cpp
Normal file
2633
clang/test/OpenMP/target_dyn_groupprivate_codegen.cpp
Normal file
File diff suppressed because it is too large
Load Diff
@ -190,6 +190,16 @@ enum class OMPScheduleType {
|
||||
LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue */ ModifierMask)
|
||||
};
|
||||
|
||||
/// The fallback types for the dyn_groupprivate clause.
|
||||
enum class OMPDynGroupprivateFallbackType : uint64_t {
|
||||
/// Abort the execution.
|
||||
Abort = 0,
|
||||
/// Return null pointer.
|
||||
Null = 1,
|
||||
/// Allocate from a implementation defined memory space.
|
||||
DefaultMem = 2
|
||||
};
|
||||
|
||||
// Default OpenMP mapper name suffix.
|
||||
inline constexpr const char *OmpDefaultMapperName = ".omp.default.mapper";
|
||||
|
||||
|
||||
@ -2446,20 +2446,24 @@ public:
|
||||
/// The number of threads.
|
||||
ArrayRef<Value *> NumThreads;
|
||||
/// The size of the dynamic shared memory.
|
||||
Value *DynCGGroupMem = nullptr;
|
||||
Value *DynCGroupMem = nullptr;
|
||||
/// True if the kernel has 'no wait' clause.
|
||||
bool HasNoWait = false;
|
||||
/// The fallback mechanism for the shared memory.
|
||||
omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback =
|
||||
omp::OMPDynGroupprivateFallbackType::Abort;
|
||||
|
||||
// Constructors for TargetKernelArgs.
|
||||
TargetKernelArgs() = default;
|
||||
TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
|
||||
Value *NumIterations, ArrayRef<Value *> NumTeams,
|
||||
ArrayRef<Value *> NumThreads, Value *DynCGGroupMem,
|
||||
bool HasNoWait)
|
||||
ArrayRef<Value *> NumThreads, Value *DynCGroupMem,
|
||||
bool HasNoWait,
|
||||
omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback)
|
||||
: NumTargetItems(NumTargetItems), RTArgs(RTArgs),
|
||||
NumIterations(NumIterations), NumTeams(NumTeams),
|
||||
NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
|
||||
HasNoWait(HasNoWait) {}
|
||||
NumThreads(NumThreads), DynCGroupMem(DynCGroupMem),
|
||||
HasNoWait(HasNoWait), DynCGroupMemFallback(DynCGroupMemFallback) {}
|
||||
};
|
||||
|
||||
/// Create the kernel args vector used by emitTargetKernel. This function
|
||||
@ -3244,6 +3248,10 @@ public:
|
||||
/// dependency information as passed in the depend clause
|
||||
/// \param HasNowait Whether the target construct has a `nowait` clause or
|
||||
/// not.
|
||||
/// \param DynCGroupMem The size of the dynamic groupprivate memory for each
|
||||
/// cgroup.
|
||||
/// \param DynCGroupMem The fallback mechanism to execute if the requested
|
||||
/// cgroup memory cannot be provided.
|
||||
LLVM_ABI InsertPointOrErrorTy createTarget(
|
||||
const LocationDescription &Loc, bool IsOffloadEntry,
|
||||
OpenMPIRBuilder::InsertPointTy AllocaIP,
|
||||
@ -3255,7 +3263,10 @@ public:
|
||||
TargetBodyGenCallbackTy BodyGenCB,
|
||||
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
|
||||
CustomMapperCallbackTy CustomMapperCB,
|
||||
const SmallVector<DependData> &Dependencies, bool HasNowait = false);
|
||||
const SmallVector<DependData> &Dependencies, bool HasNowait = false,
|
||||
Value *DynCGroupMem = nullptr,
|
||||
omp::OMPDynGroupprivateFallbackType DynCGroupMemFallback =
|
||||
omp::OMPDynGroupprivateFallbackType::Abort);
|
||||
|
||||
/// Returns __kmpc_for_static_init_* runtime function for the specified
|
||||
/// size \a IVSize and sign \a IVSigned. Will create a distribute call
|
||||
|
||||
@ -530,7 +530,13 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
|
||||
auto Int32Ty = Type::getInt32Ty(Builder.getContext());
|
||||
constexpr size_t MaxDim = 3;
|
||||
Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, MaxDim));
|
||||
Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);
|
||||
|
||||
Value *HasNoWaitFlag = Builder.getInt64(KernelArgs.HasNoWait);
|
||||
|
||||
Value *DynCGroupMemFallbackFlag =
|
||||
Builder.getInt64(static_cast<uint64_t>(KernelArgs.DynCGroupMemFallback));
|
||||
DynCGroupMemFallbackFlag = Builder.CreateShl(DynCGroupMemFallbackFlag, 2);
|
||||
Value *Flags = Builder.CreateOr(HasNoWaitFlag, DynCGroupMemFallbackFlag);
|
||||
|
||||
assert(!KernelArgs.NumTeams.empty() && !KernelArgs.NumThreads.empty());
|
||||
|
||||
@ -559,7 +565,7 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
|
||||
Flags,
|
||||
NumTeams3D,
|
||||
NumThreads3D,
|
||||
KernelArgs.DynCGGroupMem};
|
||||
KernelArgs.DynCGroupMem};
|
||||
}
|
||||
|
||||
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
|
||||
@ -8224,7 +8230,8 @@ static void emitTargetCall(
|
||||
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
|
||||
OpenMPIRBuilder::CustomMapperCallbackTy CustomMapperCB,
|
||||
const SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
|
||||
bool HasNoWait) {
|
||||
bool HasNoWait, Value *DynCGroupMem,
|
||||
OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
|
||||
// Generate a function call to the host fallback implementation of the target
|
||||
// region. This is called by the host when no offload entry was generated for
|
||||
// the target region and when the offloading call fails at runtime.
|
||||
@ -8360,12 +8367,13 @@ static void emitTargetCall(
|
||||
/*isSigned=*/false)
|
||||
: Builder.getInt64(0);
|
||||
|
||||
// TODO: Use correct DynCGGroupMem
|
||||
Value *DynCGGroupMem = Builder.getInt32(0);
|
||||
// Request zero groupprivate bytes by default.
|
||||
if (!DynCGroupMem)
|
||||
DynCGroupMem = Builder.getInt32(0);
|
||||
|
||||
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
|
||||
NumTeamsC, NumThreadsC,
|
||||
DynCGGroupMem, HasNoWait);
|
||||
KArgs = OpenMPIRBuilder::TargetKernelArgs(
|
||||
NumTargetItems, RTArgs, TripCount, NumTeamsC, NumThreadsC, DynCGroupMem,
|
||||
HasNoWait, DynCGroupMemFallback);
|
||||
|
||||
// Assume no error was returned because TaskBodyCB and
|
||||
// EmitTargetCallFallbackCB don't produce any.
|
||||
@ -8414,7 +8422,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
|
||||
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
|
||||
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
|
||||
CustomMapperCallbackTy CustomMapperCB,
|
||||
const SmallVector<DependData> &Dependencies, bool HasNowait) {
|
||||
const SmallVector<DependData> &Dependencies, bool HasNowait,
|
||||
Value *DynCGroupMem, OMPDynGroupprivateFallbackType DynCGroupMemFallback) {
|
||||
|
||||
if (!updateToLocation(Loc))
|
||||
return InsertPointTy();
|
||||
@ -8437,7 +8446,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
|
||||
if (!Config.isTargetDevice())
|
||||
emitTargetCall(*this, Builder, AllocaIP, Info, DefaultAttrs, RuntimeAttrs,
|
||||
IfCond, OutlinedFn, OutlinedFnID, Inputs, GenMapInfoCB,
|
||||
CustomMapperCB, Dependencies, HasNowait);
|
||||
CustomMapperCB, Dependencies, HasNowait, DynCGroupMem,
|
||||
DynCGroupMemFallback);
|
||||
return Builder.saveIP();
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user