[flang][OpenMP] Introduce WithReason<T> for nest/sequence properties (#187563)

This helper class contains an optional value and a "reason" message. It
replaces the uses of std::pair<optional<...>, Reason>.

Issue: https://github.com/llvm/llvm-project/issues/185287
This commit is contained in:
Krzysztof Parzyszek 2026-03-20 14:25:47 -05:00 committed by GitHub
parent 78b651a2cb
commit cfc94a6fd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 125 additions and 63 deletions

View File

@ -114,20 +114,45 @@ MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp);
/// A representation of a "because" message.
struct Reason {
Reason() = default;
Reason(Reason &&) = default;
Reason(const Reason &);
Reason &operator=(Reason &&) = default;
Reason &operator=(const Reason &);
parser::Messages msgs;
template <typename... Ts> Reason &Say(Ts &&...args) {
msgs.Say(std::forward<Ts>(args)...);
return *this;
}
operator bool() const { return !msgs.empty(); }
parser::Message &AttachTo(parser::Message &msg);
Reason &Append(const Reason &other) {
CopyFrom(other);
return *this;
}
operator bool() const { return !msgs.empty(); }
private:
void CopyFrom(const Reason &other);
};
std::pair<std::optional<int64_t>, Reason> GetArgumentValueWithReason(
// A property with an explanation of its value. Both, the property and the
// reason are optional (the reason can have no messages in it).
template <typename T> struct WithReason {
std::optional<T> value;
Reason reason;
WithReason() = default;
WithReason(std::optional<T> v, const Reason &r = Reason())
: value(v), reason(r) {}
operator bool() const { return value.has_value(); }
};
WithReason<int64_t> GetArgumentValueWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
unsigned version);
std::pair<std::optional<int64_t>, Reason> GetNumArgumentsWithReason(
WithReason<int64_t> GetNumArgumentsWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
unsigned version);
@ -135,20 +160,21 @@ bool IsLoopTransforming(llvm::omp::Directive dir);
bool IsFullUnroll(const parser::OpenMPLoopConstruct &x);
// Return the depth of the affected nests:
// {affected-depth, must-be-perfect-nest, reason}.
std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
// {affected-depth, reason, must-be-perfect-nest}.
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
const parser::OmpDirectiveSpecification &spec, unsigned version);
// Return the range of the affected nests in the sequence:
// {first, count, reason}.
// If the range is "the whole sequence", the return value will be {1, -1, ...}.
std::tuple<std::optional<int64_t>, std::optional<int64_t>, Reason>
GetAffectedLoopRangeWithReason(
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
const parser::OmpDirectiveSpecification &spec, unsigned version);
// Count the required loop count from range. If count == -1, return -1,
// indicating all loops in the sequence.
std::optional<int64_t> GetRequiredCount(
std::optional<int64_t> first, std::optional<int64_t> count);
std::optional<int64_t> GetRequiredCount(
std::optional<std::pair<int64_t, int64_t>> range);
struct LoopSequence {
LoopSequence(const parser::ExecutionPartConstruct &root, unsigned version,
@ -173,7 +199,7 @@ struct LoopSequence {
bool isNest() const { return length_ && *length_ == 1; }
std::optional<int64_t> length() const { return length_; }
Depth depth() const { return depth_; }
const Depth &depth() const { return depth_; }
const std::vector<LoopSequence> &children() const { return children_; }
private:

View File

@ -247,6 +247,7 @@ void OmpStructureChecker::CheckNestedConstruct(
const parser::OmpDirectiveSpecification &beginSpec{x.BeginDir()};
llvm::omp::Directive dir{beginSpec.DirId()};
unsigned version{context_.langOptions().OpenMPVersion};
parser::CharBlock beginSource{beginSpec.DirName().source};
// End-directive is not allowed in such cases:
// do 100 i = ...
@ -257,7 +258,6 @@ void OmpStructureChecker::CheckNestedConstruct(
auto &flags{std::get<parser::OmpDirectiveSpecification::Flags>(beginSpec.t)};
if (flags.test(parser::OmpDirectiveSpecification::Flag::CrossesLabelDo)) {
if (auto &endSpec{x.EndDir()}) {
parser::CharBlock beginSource{beginSpec.DirName().source};
context_
.Say(endSpec->DirName().source,
"END %s directive is not allowed when the construct does not contain all loops that share a loop-terminating statement"_err_en_US,
@ -292,30 +292,29 @@ void OmpStructureChecker::CheckNestedConstruct(
// Check if a loop-nest-associated construct has only one top-level loop
// in it.
auto [needFirst, needCount, rangeReason]{
GetAffectedLoopRangeWithReason(beginSpec, version)};
auto needRange{GetAffectedLoopRangeWithReason(beginSpec, version)};
if (std::optional<int64_t> numLoops{sequence.length()}) {
if (*numLoops == 0) {
context_.Say(beginSpec.DirName().source,
context_.Say(beginSource,
"This construct should contain a DO-loop or a loop-nest-generating OpenMP construct"_err_en_US);
} else {
auto assoc{llvm::omp::getDirectiveAssociation(dir)};
if (*numLoops > 1 && assoc == llvm::omp::Association::LoopNest) {
context_.Say(beginSpec.DirName().source,
context_.Say(beginSource,
"This construct applies to a loop nest, but has a loop sequence of "
"length %" PRId64 ""_err_en_US,
*numLoops);
}
if (assoc == llvm::omp::Association::LoopSeq) {
if (auto requiredCount{GetRequiredCount(needFirst, needCount)}) {
if (auto requiredCount{GetRequiredCount(needRange.value)}) {
if (*requiredCount > 0 && *numLoops < *requiredCount) {
auto &msg{context_.Say(beginSpec.DirName().source,
auto &msg{context_.Say(beginSource,
"This construct requires a sequence of %" PRId64
" loops, but the loop sequence has a length of %" PRId64
""_err_en_US,
*requiredCount, *numLoops)};
rangeReason.AttachTo(msg);
needRange.reason.AttachTo(msg);
}
}
}
@ -323,29 +322,29 @@ void OmpStructureChecker::CheckNestedConstruct(
}
// Check requirements on nest depth.
auto [needDepth, needPerfect, depthReason]{
auto [needDepth, needPerfect]{
GetAffectedNestDepthWithReason(beginSpec, version)};
auto [haveSema, havePerf]{sequence.depth()};
auto &[haveSema, havePerf]{sequence.depth()};
if (dir != llvm::omp::Directive::OMPD_fuse) {
auto haveDepth = needPerfect ? havePerf : haveSema;
auto &haveDepth = needPerfect ? havePerf : haveSema;
// If the present depth is 0, it's likely that the construct doesn't
// have any loops in it, which would be diagnosed above.
if (needDepth && haveDepth && *haveDepth > 0) {
if (*needDepth > *haveDepth) {
if (needDepth && haveDepth > 0) {
if (*needDepth.value > *haveDepth) {
if (needPerfect) {
auto &msg{context_.Say(beginSpec.DirName().source,
auto &msg{context_.Say(beginSource,
"This construct requires a perfect nest of depth %" PRId64
", but the associated nest is a perfect nest of depth %" PRId64
""_err_en_US,
*needDepth, *haveDepth)};
depthReason.AttachTo(msg);
*needDepth.value, *haveDepth)};
needDepth.reason.AttachTo(msg);
} else {
auto &msg{context_.Say(beginSpec.DirName().source,
auto &msg{context_.Say(beginSource,
"This construct requires a nest of depth %" PRId64
", but the associated nest has a depth of %" PRId64 ""_err_en_US,
*needDepth, *haveDepth)};
depthReason.AttachTo(msg);
*needDepth.value, *haveDepth)};
needDepth.reason.AttachTo(msg);
}
}
}

View File

@ -532,12 +532,30 @@ MaybeExpr MakeEvaluateExpr(const parser::OmpStylizedInstance &inp) {
instance.u);
}
Reason::Reason(const Reason &other) { //
CopyFrom(other);
}
Reason &Reason::operator=(const Reason &other) {
if (this != &other) {
msgs.clear();
CopyFrom(other);
}
return *this;
}
void Reason::CopyFrom(const Reason &other) {
for (auto &msg : other.msgs.messages()) {
msgs.Say(parser::Message(msg));
}
}
parser::Message &Reason::AttachTo(parser::Message &msg) {
msgs.AttachTo(msg);
return msg;
}
std::pair<std::optional<int64_t>, Reason> GetArgumentValueWithReason(
WithReason<int64_t> GetArgumentValueWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
unsigned version) {
if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
@ -552,12 +570,11 @@ std::pair<std::optional<int64_t>, Reason> GetArgumentValueWithReason(
}
}
}
return {std::nullopt, Reason()};
return {};
}
template <typename T>
static std::pair<std::optional<int64_t>, Reason>
GetNumArgumentsWithReasonForType(
static WithReason<int64_t> GetNumArgumentsWithReasonForType(
const parser::OmpClause &clause, const std::string &name) {
if (auto *args{parser::Unwrap<std::list<T>>(clause.u)}) {
auto num{static_cast<int64_t>(args->size())};
@ -567,10 +584,10 @@ GetNumArgumentsWithReasonForType(
name.c_str(), num);
return {num, std::move(reason)};
}
return {std::nullopt, Reason()};
return {};
}
std::pair<std::optional<int64_t>, Reason> GetNumArgumentsWithReason(
WithReason<int64_t> GetNumArgumentsWithReason(
const parser::OmpDirectiveSpecification &spec, llvm::omp::Clause clauseId,
unsigned version) {
if (auto *clause{parser::omp::FindClause(spec, clauseId)}) {
@ -578,22 +595,18 @@ std::pair<std::optional<int64_t>, Reason> GetNumArgumentsWithReason(
// Try the types used for list items.
{
using Ty = parser::ScalarIntExpr;
if (auto [num, reason]{
GetNumArgumentsWithReasonForType<Ty>(*clause, name)};
num) {
return {num, std::move(reason)};
if (auto n{GetNumArgumentsWithReasonForType<Ty>(*clause, name)}) {
return n;
}
}
{
using Ty = parser::ScalarIntConstantExpr;
if (auto [num, reason]{
GetNumArgumentsWithReasonForType<Ty>(*clause, name)};
num) {
return {num, std::move(reason)};
if (auto n{GetNumArgumentsWithReasonForType<Ty>(*clause, name)}) {
return n;
}
}
}
return {std::nullopt, Reason()};
return {};
}
bool IsLoopTransforming(llvm::omp::Directive dir) {
@ -791,9 +804,25 @@ bool IsTransformableLoop(const parser::ExecutionPartConstruct &epc) {
return false;
}
template <typename T,
typename = std::enable_if_t<std::is_arithmetic_v<llvm::remove_cvref_t<T>>>>
WithReason<T> operator+(const WithReason<T> &a, const WithReason<T> &b) {
if (a.value && b.value) {
return WithReason<T>{
*a.value + *b.value, Reason().Append(a.reason).Append(b.reason)};
}
return WithReason<T>();
}
template <typename T,
typename = std::enable_if_t<std::is_arithmetic_v<llvm::remove_cvref_t<T>>>>
WithReason<T> operator+(T a, const WithReason<T> &b) {
return WithReason<T>{a, Reason()} + b;
}
// Return the depth of the affected nests:
// {affected-depth, must-be-perfect-nest}.
std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
std::pair<WithReason<int64_t>, bool> GetAffectedNestDepthWithReason(
const parser::OmpDirectiveSpecification &spec, unsigned version) {
llvm::omp::Directive dir{spec.DirId()};
bool allowsCollapse{llvm::omp::isAllowedClauseForDirective(
@ -812,7 +841,7 @@ std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
reason = std::move(ro);
}
}
return {count, true, std::move(reason)};
return {{count, std::move(reason)}, true};
}
if (IsLoopTransforming(dir)) {
@ -822,7 +851,7 @@ std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_permutation)) {
auto [num, reason]{GetNumArgumentsWithReason(
spec, llvm::omp::Clause::OMPC_permutation, version)};
return {num, true, std::move(reason)};
return {{num, std::move(reason)}, true};
}
// PERMUTATION not specified, assume PERMUTATION(2, 1).
std::string name{parser::omp::GetUpperName(
@ -831,21 +860,21 @@ std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
reason.Say(spec.source,
"%s clause was not specified, %s(2, 1) was assumed"_because_en_US,
name.c_str(), name.c_str());
return {2, true, std::move(reason)};
return {{2, std::move(reason)}, true};
}
case llvm::omp::Directive::OMPD_stripe:
case llvm::omp::Directive::OMPD_tile: {
// Get the length of the argument list to SIZES.
auto [num, reason]{GetNumArgumentsWithReason(
spec, llvm::omp::Clause::OMPC_sizes, version)};
return {num, true, std::move(reason)};
return {{num, std::move(reason)}, true};
}
case llvm::omp::Directive::OMPD_fuse: {
// Get the value from the argument to DEPTH.
if (parser::omp::FindClause(spec, llvm::omp::Clause::OMPC_depth)) {
auto [count, reason]{GetArgumentValueWithReason(
spec, llvm::omp::Clause::OMPC_depth, version)};
return {count, true, std::move(reason)};
return {{count, std::move(reason)}, true};
}
std::string name{
parser::omp::GetUpperName(llvm::omp::Clause::OMPC_depth, version)};
@ -853,11 +882,11 @@ std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
reason.Say(spec.source,
"%s clause was not specified, a value of 1 was assumed"_because_en_US,
name.c_str());
return {1, true, std::move(reason)};
return {{1, std::move(reason)}, true};
}
case llvm::omp::Directive::OMPD_reverse:
case llvm::omp::Directive::OMPD_unroll:
return {1, false, Reason()};
return {WithReason<int64_t>(1), false};
// TODO: case llvm::omp::Directive::OMPD_flatten:
// TODO: case llvm::omp::Directive::OMPD_split:
default:
@ -865,13 +894,12 @@ std::tuple<std::optional<int64_t>, bool, Reason> GetAffectedNestDepthWithReason(
}
}
return {std::nullopt, false, Reason()};
return {{}, false};
}
// Return the range of the affected nests in the sequence:
// {first, count, std::move(reason)}.
std::tuple<std::optional<int64_t>, std::optional<int64_t>, Reason>
GetAffectedLoopRangeWithReason(
WithReason<std::pair<int64_t, int64_t>> GetAffectedLoopRangeWithReason(
const parser::OmpDirectiveSpecification &spec, unsigned version) {
llvm::omp::Directive dir{spec.DirId()};
@ -883,7 +911,7 @@ GetAffectedLoopRangeWithReason(
std::optional<int64_t> first{GetIntValue(std::get<0>(range.t))};
std::optional<int64_t> count{GetIntValue(std::get<1>(range.t))};
if (!first || !count || *first <= 0 || *count <= 0) {
return {std::nullopt, std::nullopt, Reason()};
return {};
}
std::string name{parser::omp::GetUpperName(
llvm::omp::Clause::OMPC_looprange, version)};
@ -892,7 +920,7 @@ GetAffectedLoopRangeWithReason(
"%s clause was specified with a count of %" PRId64
" starting at loop %" PRId64 ""_because_en_US,
name.c_str(), *count, *first);
return {*first, *count, std::move(reason)};
return {std::make_pair(*first, *count), std::move(reason)};
}
// If LOOPRANGE was not found, return {1, -1}, where -1 means "the whole
// associated sequence".
@ -900,14 +928,14 @@ GetAffectedLoopRangeWithReason(
reason.Say(spec.source,
"%s clause was not specified, a value of 1 was assumed"_because_en_US,
name.c_str());
return {1, -1, std::move(reason)};
return {std::make_pair(1, -1), std::move(reason)};
}
assert(llvm::omp::getDirectiveAssociation(dir) ==
llvm::omp::Association::LoopNest &&
"Expecting loop-nest-associated construct");
// For loop-nest constructs, a single loop-nest is affected.
return {1, 1, Reason()};
return {std::make_pair(1, 1), Reason()};
}
std::optional<int64_t> GetRequiredCount(
@ -922,6 +950,14 @@ std::optional<int64_t> GetRequiredCount(
return std::nullopt;
}
std::optional<int64_t> GetRequiredCount(
std::optional<std::pair<int64_t, int64_t>> range) {
if (range) {
return GetRequiredCount(range->first, range->second);
}
return GetRequiredCount(std::nullopt, std::nullopt);
}
#ifdef EXPENSIVE_CHECKS
namespace {
/// Check that for every value x of type T, there will be a "source" member
@ -1174,9 +1210,8 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
// The result is a perfect nest only if all loop in the sequence
// are fused.
if (value && nestedLength) {
auto [first, count, _]{
GetAffectedLoopRangeWithReason(beginSpec, version_)};
if (auto required{GetRequiredCount(first, count)}) {
auto range{GetAffectedLoopRangeWithReason(beginSpec, version_)};
if (auto required{GetRequiredCount(range.value)}) {
if (*required == -1 || *required == *nestedLength) {
return Depth{value, value};
}
@ -1185,6 +1220,7 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
}
return Depth{std::nullopt, std::nullopt};
}
// FUSE cannot create a nest of depth > 1 without DEPTH clause.
return Depth{1, 1};
case llvm::omp::Directive::OMPD_interchange:
case llvm::omp::Directive::OMPD_nothing:
@ -1201,7 +1237,7 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
return Depth{plus(num, semaDepth), plus(num, perfDepth)};
}
// The SIZES clause is mandatory, if it's missing the result is unknown.
return {std::nullopt, std::nullopt};
return {};
case llvm::omp::Directive::OMPD_unroll:
if (IsFullUnroll(omp)) {
return Depth{0, 0};
@ -1231,9 +1267,10 @@ LoopSequence::Depth LoopSequence::calculateDepths() const {
}
LoopSequence::Depth LoopSequence::getNestedDepths() const {
if (length() != 1) {
return Depth{0, 0};
if (!isNest()) {
return {std::nullopt, std::nullopt};
} else if (children_.empty()) {
// No children, but length == 1.
assert(entry_->owner &&
parser::Unwrap<parser::DoConstruct>(entry_->owner) &&
"Expecting DO construct");