Update TargetRegionFlags to mirror OMPTgtExecModeFlags

This commit is contained in:
Sergio Afonso 2025-08-13 12:53:34 +01:00
parent 39800face1
commit 9e948a58af
4 changed files with 64 additions and 41 deletions

View File

@ -223,19 +223,19 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
//===----------------------------------------------------------------------===//
// target_region_flags enum.
// target_exec_mode enum.
//===----------------------------------------------------------------------===//
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
def TargetExecModeBare : I32EnumAttrCase<"bare", 0>;
def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>;
def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>;
def TargetRegionFlags : OpenMP_BitEnumAttr<
"TargetRegionFlags",
"target region property flags", [
TargetRegionFlagsNone,
TargetRegionFlagsSpmd,
TargetRegionFlagsTripCount
def TargetExecMode : OpenMP_I32EnumAttr<
"TargetExecMode",
"target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [
TargetExecModeBare,
TargetExecModeGeneric,
TargetExecModeSpmd,
]>;
//===----------------------------------------------------------------------===//

View File

@ -1517,13 +1517,17 @@ def TargetOp : OpenMP_Op<"target", traits = [
/// operations, the top level one will be the one captured.
Operation *getInnermostCapturedOmpOp();
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
/// contents of the target region.
/// Infers the kernel type (Bare, Generic or SPMD) based on the contents of
/// the target region. Detection of the Bare kernel type is not supported yet.
///
/// \param capturedOp result of a still valid (no modifications made to any
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
static ::mlir::omp::TargetRegionFlags
getKernelExecFlags(Operation *capturedOp);
/// \param hostEvalTripCount optional output argument (may be null) to store
/// whether this kernel wraps a loop whose bounds must be evaluated on the
/// host prior to launching it.
static ::mlir::omp::TargetExecMode
getKernelExecFlags(Operation *capturedOp,
bool *hostEvalTripCount = nullptr);
}] # clausesExtraClassDeclaration;
let assemblyFormat = clausesAssemblyFormat # [{

View File

@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() {
return emitError("target containing multiple 'omp.teams' nested ops");
// Check that host_eval values are only used in legal ways.
bool hostEvalTripCount;
Operation *capturedOp = getInnermostCapturedOmpOp();
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() {
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
if (execMode == TargetExecMode::spmd &&
parallelOp->isAncestor(capturedOp) &&
hostEvalArg == parallelOp.getNumThreads())
continue;
@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() {
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
loopNestOp.getOperation() == capturedOp &&
if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
});
}
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
bool *hostEvalTripCount) {
// TODO: Support detection of bare kernel mode.
// A non-null captured op is only valid if it resides inside of a TargetOp
// and is the result of calling getInnermostCapturedOmpOp() on it.
TargetOp targetOp =
@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
"unexpected captured op");
if (hostEvalTripCount)
*hostEvalTripCount = false;
// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
return TargetRegionFlags::none;
return TargetExecMode::generic;
// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
return TargetRegionFlags::none;
return TargetExecMode::generic;
// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
return TargetRegionFlags::none;
return TargetExecMode::generic;
innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
return TargetRegionFlags::none;
return TargetExecMode::generic;
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::none;
return TargetExecMode::generic;
Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::none;
return TargetExecMode::generic;
if (teamsOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
if (teamsOp->getParentOp() == targetOp.getOperation()) {
if (hostEvalTripCount)
*hostEvalTripCount = true;
return TargetExecMode::spmd;
}
}
// Detect target-teams-distribute[-simd] and target-teams-loop.
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::none;
return TargetExecMode::generic;
if (teamsOp->getParentOp() != targetOp.getOperation())
return TargetRegionFlags::none;
return TargetExecMode::generic;
if (hostEvalTripCount)
*hostEvalTripCount = true;
if (isa<LoopOp>(innermostWrapper))
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
return TargetExecMode::spmd;
return TargetRegionFlags::trip_count;
return TargetExecMode::generic;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::none;
return TargetExecMode::generic;
if (parallelOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd;
return TargetExecMode::spmd;
}
return TargetRegionFlags::none;
return TargetExecMode::generic;
}
//===----------------------------------------------------------------------===//

View File

@ -5354,11 +5354,18 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
attrs.ExecFlags =
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp);
switch (execMode) {
case omp::TargetExecMode::bare:
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE;
break;
case omp::TargetExecMode::generic:
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
break;
case omp::TargetExecMode::spmd:
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
break;
}
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@ -5408,8 +5415,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
omp::TargetRegionFlags::trip_count)) {
bool hostEvalTripCount;
targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount);
if (hostEvalTripCount) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;