[MLIR][OpenMP] Remove Generic-SPMD early detection

This patch removes the logic in MLIR that attempts to identify Generic kernels
that could be executed in SPMD mode.

This optimization is done by the OpenMPOpt pass for Clang and was only required
here to circumvent missing support for the new DeviceRTL APIs used in the MLIR
to LLVM IR translation that Clang doesn't currently use (e.g.
`kmpc_distribute_static_loop`). Removing the checks in MLIR avoids duplicating
logic that should be centralized in the OpenMPOpt pass.

Additionally, offloading kernels currently compiled through the OpenMP dialect
fail to run parallel regions properly when in Generic mode. Disabling early
detection makes this issue apparent for a range of kernels where it was
previously masked by having them run in SPMD mode.
This commit is contained in:
Sergio Afonso 2025-05-21 16:29:23 +01:00
parent b09b05a83e
commit 39800face1
4 changed files with 17 additions and 51 deletions

View File

@ -227,15 +227,13 @@ def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
//===----------------------------------------------------------------------===//
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>;
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>;
def TargetRegionFlags : OpenMP_BitEnumAttr<
"TargetRegionFlags",
"target region property flags", [
TargetRegionFlagsNone,
TargetRegionFlagsGeneric,
TargetRegionFlagsSpmd,
TargetRegionFlagsTripCount
]>;

View File

@ -2117,7 +2117,7 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
@ -2130,24 +2130,24 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
if (teamsOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
@ -2156,53 +2156,27 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
if (teamsOp->getParentOp() != targetOp.getOperation())
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
if (isa<LoopOp>(innermostWrapper))
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
// Find single immediately nested captured omp.parallel and add spmd flag
// (generic-spmd case).
//
// TODO: This shouldn't have to be done here, as it is too easy to break.
// The openmp-opt pass should be updated to be able to promote kernels like
// this from "Generic" to "Generic-SPMD". However, the use of the
// `kmpc_distribute_static_loop` family of functions produced by the
// OMPIRBuilder for these kernels prevents that from working.
Dialect *ompDialect = targetOp->getDialect();
Operation *nestedCapture = findCapturedOmpOp(
capturedOp, /*checkSingleMandatoryExec=*/false,
[&](Operation *sibling) {
return sibling && (ompDialect != sibling->getDialect() ||
sibling->hasTrait<OpTrait::IsTerminator>());
});
TargetRegionFlags result =
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
if (!nestedCapture)
return result;
while (nestedCapture->getParentOp() != capturedOp)
nestedCapture = nestedCapture->getParentOp();
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
: result;
return TargetRegionFlags::trip_count;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
if (parallelOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd;
}
return TargetRegionFlags::generic;
return TargetRegionFlags::none;
}
//===----------------------------------------------------------------------===//

View File

@ -5355,16 +5355,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
assert(
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
omp::TargetRegionFlags::spmd) &&
"invalid kernel flags");
attrs.ExecFlags =
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC;
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;

View File

@ -87,7 +87,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
}
}
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:3]]
// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 [[EXEC_MODE:1]]
// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE]], {{.*}}},