diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td index ce0ebabd5812..deb2fba1cd79 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td @@ -223,19 +223,19 @@ def ScheduleModifier : OpenMP_I32EnumAttr< def ScheduleModifierAttr : OpenMP_EnumAttr; //===----------------------------------------------------------------------===// -// target_region_flags enum. +// target_exec_mode enum. //===----------------------------------------------------------------------===// -def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">; -def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 0>; -def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 1>; +def TargetExecModeBare : I32EnumAttrCase<"bare", 0>; +def TargetExecModeGeneric : I32EnumAttrCase<"generic", 1>; +def TargetExecModeSpmd : I32EnumAttrCase<"spmd", 2>; -def TargetRegionFlags : OpenMP_BitEnumAttr< - "TargetRegionFlags", - "target region property flags", [ - TargetRegionFlagsNone, - TargetRegionFlagsSpmd, - TargetRegionFlagsTripCount +def TargetExecMode : OpenMP_I32EnumAttr< + "TargetExecMode", + "target execution mode, mirroring the `OMPTgtExecModeFlags` LLVM enum", [ + TargetExecModeBare, + TargetExecModeGeneric, + TargetExecModeSpmd, ]>; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index be114ea4fb63..6569905c5fae 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1517,13 +1517,17 @@ def TargetOp : OpenMP_Op<"target", traits = [ /// operations, the top level one will be the one captured. Operation *getInnermostCapturedOmpOp(); - /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the - /// contents of the target region. + /// Infers the kernel type (Bare, Generic or SPMD) based on the contents of + /// the target region. /// /// \param capturedOp result of a still valid (no modifications made to any /// nested operations) previous call to `getInnermostCapturedOmpOp()`. - static ::mlir::omp::TargetRegionFlags - getKernelExecFlags(Operation *capturedOp); + /// \param hostEvalTripCount output argument to store whether this kernel + /// wraps a loop whose bounds must be evaluated on the host prior to + /// launching it. + static ::mlir::omp::TargetExecMode + getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount = nullptr); }] # clausesExtraClassDeclaration; let assemblyFormat = clausesAssemblyFormat # [{ diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 8854e908c71f..c3c17006fe57 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1974,8 +1974,9 @@ LogicalResult TargetOp::verifyRegions() { return emitError("target containing multiple 'omp.teams' nested ops"); // Check that host_eval values are only used in legal ways. + bool hostEvalTripCount; Operation *capturedOp = getInnermostCapturedOmpOp(); - TargetRegionFlags execFlags = getKernelExecFlags(capturedOp); + TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount); for (Value hostEvalArg : cast(getOperation()).getHostEvalBlockArgs()) { for (Operation *user : hostEvalArg.getUsers()) { @@ -1990,7 +1991,7 @@ LogicalResult TargetOp::verifyRegions() { "and 'thread_limit' in 'omp.teams'"; } if (auto parallelOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && + if (execMode == TargetExecMode::spmd && parallelOp->isAncestor(capturedOp) && hostEvalArg == parallelOp.getNumThreads()) continue; @@ -2000,8 +2001,7 @@ LogicalResult TargetOp::verifyRegions() { "'omp.parallel' when representing target SPMD"; } if (auto loopNestOp = dyn_cast(user)) { - if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) && - loopNestOp.getOperation() == capturedOp && + if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp && (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) @@ -2106,7 +2106,9 @@ Operation *TargetOp::getInnermostCapturedOmpOp() { }); } -TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { +TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp, + bool *hostEvalTripCount) { + // TODO: Support detection of bare kernel mode. // A non-null captured op is only valid if it resides inside of a TargetOp // and is the result of calling getInnermostCapturedOmpOp() on it. TargetOp targetOp = @@ -2115,9 +2117,12 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) && "unexpected captured op"); + if (hostEvalTripCount) + *hostEvalTripCount = false; + // If it's not capturing a loop, it's a default target region. if (!isa_and_present(capturedOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; // Get the innermost non-simd loop wrapper. SmallVector loopWrappers; @@ -2130,53 +2135,59 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) { auto numWrappers = std::distance(innermostWrapper, loopWrappers.end()); if (numWrappers != 1 && numWrappers != 2) - return TargetRegionFlags::none; + return TargetExecMode::generic; // Detect target-teams-distribute-parallel-wsloop[-simd]. if (numWrappers == 2) { if (!isa(innermostWrapper)) - return TargetRegionFlags::none; + return TargetExecMode::generic; innermostWrapper = std::next(innermostWrapper); if (!isa(innermostWrapper)) - return TargetRegionFlags::none; + return TargetExecMode::generic; Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; Operation *teamsOp = parallelOp->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; - if (teamsOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + if (teamsOp->getParentOp() == targetOp.getOperation()) { + if (hostEvalTripCount) + *hostEvalTripCount = true; + return TargetExecMode::spmd; + } } // Detect target-teams-distribute[-simd] and target-teams-loop. else if (isa(innermostWrapper)) { Operation *teamsOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(teamsOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; if (teamsOp->getParentOp() != targetOp.getOperation()) - return TargetRegionFlags::none; + return TargetExecMode::generic; + + if (hostEvalTripCount) + *hostEvalTripCount = true; if (isa(innermostWrapper)) - return TargetRegionFlags::spmd | TargetRegionFlags::trip_count; + return TargetExecMode::spmd; - return TargetRegionFlags::trip_count; + return TargetExecMode::generic; } // Detect target-parallel-wsloop[-simd]. else if (isa(innermostWrapper)) { Operation *parallelOp = (*innermostWrapper)->getParentOp(); if (!isa_and_present(parallelOp)) - return TargetRegionFlags::none; + return TargetExecMode::generic; if (parallelOp->getParentOp() == targetOp.getOperation()) - return TargetRegionFlags::spmd; + return TargetExecMode::spmd; } - return TargetRegionFlags::none; + return TargetExecMode::generic; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 88601ef45911..d49cc38cd792 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -5354,11 +5354,18 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp, } // Update kernel bounds structure for the `OpenMPIRBuilder` to use. - omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp); - attrs.ExecFlags = - omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd) - ? llvm::omp::OMP_TGT_EXEC_MODE_SPMD - : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; + omp::TargetExecMode execMode = targetOp.getKernelExecFlags(capturedOp); + switch (execMode) { + case omp::TargetExecMode::bare: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_BARE; + break; + case omp::TargetExecMode::generic: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_GENERIC; + break; + case omp::TargetExecMode::spmd: + attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD; + break; + } attrs.MinTeams = minTeamsVal; attrs.MaxTeams.front() = maxTeamsVal; attrs.MinThreads = 1; @@ -5408,8 +5415,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder, if (numThreads) attrs.MaxThreads = moduleTranslation.lookupValue(numThreads); - if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp), - omp::TargetRegionFlags::trip_count)) { + bool hostEvalTripCount; + targetOp.getKernelExecFlags(capturedOp, &hostEvalTripCount); + if (hostEvalTripCount) { llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); attrs.LoopTripCount = nullptr;