Update TargetRegionFlags to mirror OMPTgtExecModeFlags
This commit is contained in:
parent
39800face1
commit
9e948a58af
@ -223,19 +223,19 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
|
||||
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// target_region_flags enum.
|
||||
// target_exec_mode enum.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
|
||||
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
|
||||
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
|
||||
def TargetExecModeBare : I32EnumAttrCase<"bare", 0>;
|
||||
def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>;
|
||||
def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>;
|
||||
|
||||
def TargetRegionFlags : OpenMP_BitEnumAttr<
|
||||
"TargetRegionFlags",
|
||||
"target region property flags", [
|
||||
TargetRegionFlagsNone,
|
||||
TargetRegionFlagsSpmd,
|
||||
TargetRegionFlagsTripCount
|
||||
def TargetExecMode : OpenMP_I32EnumAttr<
|
||||
"TargetExecMode",
|
||||
"target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [
|
||||
TargetExecModeBare,
|
||||
TargetExecModeGeneric,
|
||||
TargetExecModeSpmd,
|
||||
]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1517,13 +1517,17 @@ def TargetOp : OpenMP_Op<"target", traits = [
|
||||
/// operations, the top level one will be the one captured.
|
||||
Operation *getInnermostCapturedOmpOp();
|
||||
|
||||
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
|
||||
/// contents of the target region.
|
||||
/// Infers the kernel type (Bare, Generic or SPMD) based on the contents of
|
||||
/// the target region.
|
||||
///
|
||||
/// \param capturedOp result of a still valid (no modifications made to any
|
||||
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
|
||||
static ::mlir::omp::TargetRegionFlags
|
||||
getKernelExecFlags(Operation *capturedOp);
|
||||
/// \param hostEvalTripCount output argument to store whether this kernel
|
||||
/// wraps a loop whose bounds must be evaluated on the host prior to
|
||||
/// launching it.
|
||||
static ::mlir::omp::TargetExecMode
|
||||
getKernelExecFlags(Operation *capturedOp,
|
||||
bool *hostEvalTripCount = nullptr);
|
||||
}] # clausesExtraClassDeclaration;
|
||||
|
||||
let assemblyFormat = clausesAssemblyFormat # [{
|
||||
|
@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() {
|
||||
return emitError("target containing multiple 'omp.teams' nested ops");
|
||||
|
||||
// Check that host_eval values are only used in legal ways.
|
||||
bool hostEvalTripCount;
|
||||
Operation *capturedOp = getInnermostCapturedOmpOp();
|
||||
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
|
||||
TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
|
||||
for (Value hostEvalArg :
|
||||
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
|
||||
for (Operation *user : hostEvalArg.getUsers()) {
|
||||
@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() {
|
||||
"and 'thread_limit' in 'omp.teams'";
|
||||
}
|
||||
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
|
||||
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
|
||||
if (execMode == TargetExecMode::spmd &&
|
||||
parallelOp->isAncestor(capturedOp) &&
|
||||
hostEvalArg == parallelOp.getNumThreads())
|
||||
continue;
|
||||
@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() {
|
||||
"'omp.parallel' when representing target SPMD";
|
||||
}
|
||||
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
|
||||
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
|
||||
loopNestOp.getOperation() == capturedOp &&
|
||||
if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
|
||||
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
|
||||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
|
||||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
|
||||
@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
|
||||
});
|
||||
}
|
||||
|
||||
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
|
||||
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
|
||||
bool *hostEvalTripCount) {
|
||||
// TODO: Support detection of bare kernel mode.
|
||||
// A non-null captured op is only valid if it resides inside of a TargetOp
|
||||
// and is the result of calling getInnermostCapturedOmpOp() on it.
|
||||
TargetOp targetOp =
|
||||
@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
|
||||
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
|
||||
"unexpected captured op");
|
||||
|
||||
if (hostEvalTripCount)
|
||||
*hostEvalTripCount = false;
|
||||
|
||||
// If it's not capturing a loop, it's a default target region.
|
||||
if (!isa_and_present<LoopNestOp>(capturedOp))
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
// Get the innermost non-simd loop wrapper.
|
||||
SmallVector<LoopWrapperInterface> loopWrappers;
|
||||
@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
|
||||
|
||||
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
|
||||
if (numWrappers != 1 && numWrappers != 2)
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
// Detect target-teams-distribute-parallel-wsloop[-simd].
|
||||
if (numWrappers == 2) {
|
||||
if (!isa<WsloopOp>(innermostWrapper))
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
innermostWrapper = std::next(innermostWrapper);
|
||||
if (!isa<DistributeOp>(innermostWrapper))
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
Operation *parallelOp = (*innermostWrapper)->getParentOp();
|
||||
if (!isa_and_present<ParallelOp>(parallelOp))
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
Operation *teamsOp = parallelOp->getParentOp();
|
||||
if (!isa_and_present<TeamsOp>(teamsOp))
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
if (teamsOp->getParentOp() == targetOp.getOperation())
|
||||
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
|
||||
if (teamsOp->getParentOp() == targetOp.getOperation()) {
|
||||
if (hostEvalTripCount)
|
||||
*hostEvalTripCount = true;
|
||||
return TargetExecMode::spmd;
|
||||
}
|
||||
}
|
||||
// Detect target-teams-distribute[-simd] and target-teams-loop.
|
||||
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
|
||||
Operation *teamsOp = (*innermostWrapper)->getParentOp();
|
||||
if (!isa_and_present<TeamsOp>(teamsOp))
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
if (teamsOp->getParentOp() != targetOp.getOperation())
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
if (hostEvalTripCount)
|
||||
*hostEvalTripCount = true;
|
||||
|
||||
if (isa<LoopOp>(innermostWrapper))
|
||||
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
|
||||
return TargetExecMode::spmd;
|
||||
|
||||
return TargetRegionFlags::trip_count;
|
||||
return TargetExecMode::generic;
|
||||
}
|
||||
// Detect target-parallel-wsloop[-simd].
|
||||
else if (isa<WsloopOp>(innermostWrapper)) {
|
||||
Operation *parallelOp = (*innermostWrapper)->getParentOp();
|
||||
if (!isa_and_present<ParallelOp>(parallelOp))
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
|
||||
if (parallelOp->getParentOp() == targetOp.getOperation())
|
||||
return TargetRegionFlags::spmd;
|
||||
return TargetExecMode::spmd;
|
||||
}
|
||||
|
||||
return TargetRegionFlags::none;
|
||||
return TargetExecMode::generic;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -5354,11 +5354,18 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
|
||||
}
|
||||
|
||||
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
|
||||
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
|
||||
attrs.ExecFlags =
|
||||
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
|
||||
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
|
||||
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
|
||||
omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
|
||||
switch (execMode) {
|
||||
case omp::TargetExecMode::bare:
|
||||
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
|
||||
break;
|
||||
case omp::TargetExecMode::generic:
|
||||
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
|
||||
break;
|
||||
case omp::TargetExecMode::spmd:
|
||||
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
|
||||
break;
|
||||
}
|
||||
attrs.MinTeams = minTeamsVal;
|
||||
attrs.MaxTeams.front() = maxTeamsVal;
|
||||
attrs.MinThreads = 1;
|
||||
@ -5408,8 +5415,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
|
||||
if (numThreads)
|
||||
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
|
||||
|
||||
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
|
||||
omp::TargetRegionFlags::trip_count)) {
|
||||
bool hostEvalTripCount;
|
||||
targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
|
||||
if (hostEvalTripCount) {
|
||||
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
|
||||
attrs.LoopTripCount = nullptr;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user