diff --git a/flang/include/flang/Semantics/openmp-utils.h b/flang/include/flang/Semantics/openmp-utils.h index b8ad9ed17c72..1c54124a5738 100644 --- a/flang/include/flang/Semantics/openmp-utils.h +++ b/flang/include/flang/Semantics/openmp-utils.h @@ -22,6 +22,8 @@ #include #include +#include +#include namespace Fortran::semantics { class SemanticsContext; @@ -29,6 +31,12 @@ class Symbol; // Add this namespace to avoid potential conflicts namespace omp { +template > U AsRvalue(T &t) { + return U(t); +} + +template T &&AsRvalue(T &&t) { return std::move(t); } + // There is no consistent way to get the source of an ActionStmt, but there // is "source" in Statement. This structure keeps the ActionStmt with the // extracted source for further use. diff --git a/flang/lib/Lower/OpenMP/Atomic.cpp b/flang/lib/Lower/OpenMP/Atomic.cpp index ed0bff04ed88..ff82a36951bf 100644 --- a/flang/lib/Lower/OpenMP/Atomic.cpp +++ b/flang/lib/Lower/OpenMP/Atomic.cpp @@ -43,179 +43,6 @@ namespace omp { using namespace Fortran::lower::omp; } -namespace { -// An example of a type that can be used to get the return value from -// the visitor: -// visitor(type_identity) -> result_type -using SomeArgType = evaluate::Type; - -struct GetProc - : public evaluate::Traverse { - using Result = const evaluate::ProcedureDesignator *; - using Base = evaluate::Traverse; - GetProc() : Base(*this) {} - - using Base::operator(); - - static Result Default() { return nullptr; } - - Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; } - static Result Combine(Result a, Result b) { return a != nullptr ? 
a : b; } -}; - -struct WithType { - WithType(const evaluate::DynamicType &t) : type(t) { - assert(type.category() != common::TypeCategory::Derived && - "Type cannot be a derived type"); - } - - template // - auto visit(VisitorTy &&visitor) const - -> std::invoke_result_t { - switch (type.category()) { - case common::TypeCategory::Integer: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Unsigned: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Real: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity>{}); - case 3: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 10: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Complex: - switch (type.kind()) { - case 2: - return visitor(llvm::type_identity>{}); - case 3: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return visitor(llvm::type_identity>{}); - case 10: - return visitor(llvm::type_identity>{}); - case 16: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Logical: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - case 8: - return 
visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Character: - switch (type.kind()) { - case 1: - return visitor(llvm::type_identity>{}); - case 2: - return visitor(llvm::type_identity>{}); - case 4: - return visitor(llvm::type_identity>{}); - } - break; - case common::TypeCategory::Derived: - (void)Derived; - break; - } - llvm_unreachable("Unhandled type"); - } - - const evaluate::DynamicType &type; - -private: - // Shorter names. - static constexpr auto Character = common::TypeCategory::Character; - static constexpr auto Complex = common::TypeCategory::Complex; - static constexpr auto Derived = common::TypeCategory::Derived; - static constexpr auto Integer = common::TypeCategory::Integer; - static constexpr auto Logical = common::TypeCategory::Logical; - static constexpr auto Real = common::TypeCategory::Real; - static constexpr auto Unsigned = common::TypeCategory::Unsigned; -}; - -template > -U AsRvalue(T &t) { - U copy{t}; - return std::move(copy); -} - -template -T &&AsRvalue(T &&t) { - return std::move(t); -} - -struct ArgumentReplacer - : public evaluate::Traverse { - using Base = evaluate::Traverse; - using Result = bool; - - Result Default() const { return false; } - - ArgumentReplacer(evaluate::ActualArguments &&newArgs) - : Base(*this), args_(std::move(newArgs)) {} - - using Base::operator(); - - template - Result operator()(const evaluate::FunctionRef &x) { - assert(!done_); - auto &mut = const_cast &>(x); - mut.arguments() = args_; - done_ = true; - return true; - } - - Result Combine(Result &&a, Result &&b) { return a || b; } - -private: - bool done_{false}; - evaluate::ActualArguments &&args_; -}; -} // namespace - [[maybe_unused]] static void dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) { auto whatStr = [](int k) { @@ -412,85 +239,6 @@ makeMemOrderAttr(lower::AbstractConverter &converter, return nullptr; } -static bool replaceArgs(semantics::SomeExpr &expr, - evaluate::ActualArguments 
&&newArgs) { - return ArgumentReplacer(std::move(newArgs))(expr); -} - -static semantics::SomeExpr makeCall(const evaluate::DynamicType &type, - const evaluate::ProcedureDesignator &proc, - const evaluate::ActualArguments &args) { - return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr { - using Type = typename llvm::remove_cvref_t::type; - return evaluate::AsGenericExpr( - evaluate::FunctionRef(AsRvalue(proc), AsRvalue(args))); - }); -} - -static const evaluate::ProcedureDesignator & -getProcedureDesignator(const semantics::SomeExpr &call) { - const evaluate::ProcedureDesignator *proc = GetProc{}(call); - assert(proc && "Call has no procedure designator"); - return *proc; -} - -static semantics::SomeExpr // -genReducedMinMax(const semantics::SomeExpr &orig, - const semantics::SomeExpr *atomArg, - const std::vector &args) { - // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] - // One of the a_i's, say a_t, must be atomArg. - // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate - // call = min/max(a_t, tmp). - // Return "call". - - // The min/max intrinsics have 2 mandatory arguments, the rest is optional. - // Make sure that the "tmp = min/max(...)" doesn't promote an optional - // argument to a non-optional position. This could happen if a_t is at - // position 0 or 1. - if (args.size() <= 2) - return orig; - - evaluate::ActualArguments nonAtoms; - - auto AsActual = [](const semantics::SomeExpr &x) { - semantics::SomeExpr copy = x; - return evaluate::ActualArgument(std::move(copy)); - }; - // Semantic checks guarantee that the "atom" shows exactly once in the - // argument list (with potential conversions around it). - // For the first two (non-optional) arguments, if "atom" is among them, - // replace it with another occurrence of the other non-optional argument. - if (atomArg == &args[0]) { - // (atom, x, y...) -> (x, x, y...) 
- nonAtoms.push_back(AsActual(args[1])); - nonAtoms.push_back(AsActual(args[1])); - } else if (atomArg == &args[1]) { - // (x, atom, y...) -> (x, x, y...) - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[0])); - } else { - // (x, y, z...) -> unchanged - nonAtoms.push_back(AsActual(args[0])); - nonAtoms.push_back(AsActual(args[1])); - } - - // The rest of arguments are optional, so we can just skip "atom". - for (size_t i = 2, e = args.size(); i != e; ++i) { - if (atomArg != &args[i]) - nonAtoms.push_back(AsActual(args[i])); - } - - // The type of the intermediate min/max is the same as the type of its - // arguments, which may be different from the type of the original - // expression. The original expression may have additional coverts. - auto tmp = - makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms); - semantics::SomeExpr call = orig; - replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)}); - return call; -} - static mlir::Operation * // genAtomicRead(lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, mlir::Location loc, @@ -610,25 +358,6 @@ genAtomicUpdate(lower::AbstractConverter &converter, auto [opcode, args] = evaluate::GetTopLevelOperationIgnoreResizing(input); assert(!args.empty() && "Update operation without arguments"); - // Pass args as an argument to avoid capturing a structured binding. - const semantics::SomeExpr *atomArg = [&](auto &args) { - for (const semantics::SomeExpr &e : args) { - if (evaluate::IsSameOrConvertOf(e, atom)) - return &e; - } - llvm_unreachable("Atomic variable not in argument list"); - }(args); - - if (opcode == evaluate::operation::Operator::Min || - opcode == evaluate::operation::Operator::Max) { - // Min and max operations are expanded inline, so reduce them to - // operations with exactly two (non-optional) arguments. 
- rhs = genReducedMinMax(rhs, atomArg, args); - input = *evaluate::GetConvertInput(rhs); - std::tie(opcode, args) = - evaluate::GetTopLevelOperationIgnoreResizing(input); - atomArg = nullptr; // No longer valid. - } for (auto &arg : args) { if (!evaluate::IsSameOrConvertOf(arg, atom)) { mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc)); diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index a5fe820b1069..0c0e6158485e 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -14,6 +14,7 @@ #include "flang/Common/indirection.h" #include "flang/Evaluate/expression.h" +#include "flang/Evaluate/rewrite.h" #include "flang/Evaluate/tools.h" #include "flang/Parser/char-block.h" #include "flang/Parser/parse-tree.h" @@ -42,6 +43,8 @@ using namespace Fortran::semantics::omp; namespace operation = Fortran::evaluate::operation; +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr); + template static bool operator!=(const evaluate::Expr &e, const evaluate::Expr &f) { return !(e == f); @@ -284,7 +287,15 @@ private: AtomicAnalysis &addOp(Op &op, int what, const std::optional &maybeAssign) { op.what = what; - op.assign = maybeAssign; + if (maybeAssign) { + if (MaybeExpr rewritten{PostSemaRewrite(atom_, maybeAssign->rhs)}) { + op.assign = evaluate::Assignment( + AsRvalue(maybeAssign->lhs), std::move(*rewritten)); + op.assign->u = std::move(maybeAssign->u); + } else { + op.assign = *maybeAssign; + } + } return *this; } @@ -1293,4 +1304,118 @@ void OmpStructureChecker::Leave(const parser::OpenMPAtomicConstruct &) { dirContext_.pop_back(); } +// Rewrite min/max: +// Min and max intrinsics in Fortran take an arbitrary number of arguments +// (two or more). The first two are mandatory, the rest are optional. That +// means that arguments beyond the first two may be optional dummy arguments +// from the caller. 
In that case, a reference to such an argument will +// cause a presence test to be emitted, which cannot go inside of the atomic +// operation. Since the atom operand must be present, rewrite the min/max +// operation in a way that avoids the presence tests in the atomic code. +// For example, in +// subroutine f(atom, x, y, z) +// integer :: atom, x +// integer, optional :: y, z +// !$omp atomic update +// atom = min(atom, x, y, z) +// end +// the min operation will become +// atom = min(atom, min(x, y, z)) +// and in the final code +// // Presence check is fine here. +// tmp = min(x, y, z) +// atomic update { +// // Both operands are mandatory, no presence check needed. +// atom = min(atom, tmp) +// } +struct MinMaxRewriter : public evaluate::rewrite::Identity { + using Id = evaluate::rewrite::Identity; + using Id::operator(); + + MinMaxRewriter(const SomeExpr &atom) : atom_(atom) {} + + static bool IsMinMax(const evaluate::ProcedureDesignator &p) { + if (auto *intrin{p.GetSpecificIntrinsic()}) { + return intrin->name == "min" || intrin->name == "max"; + } + return false; + } + + // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...] + // One of the a_i's, say a_t, must be the atom. + // Generate + // min/max(a_t, min/max(a0, a1, ... [except a_t])) + template + evaluate::Expr operator()( + evaluate::Expr &&x, const evaluate::FunctionRef &f) { + const evaluate::ProcedureDesignator &proc = f.proc(); + if (!IsMinMax(proc) || f.arguments().size() <= 2) { + return Id::operator()(std::move(x), f); + } + + // Collect arguments as SomeExpr's and find out which argument + // corresponds to atom. 
+ const SomeExpr *atomArg{nullptr}; + std::vector args; + for (const std::optional &a : f.arguments()) { + if (!a) { + continue; + } + if (const SomeExpr *e{a->UnwrapExpr()}) { + if (evaluate::IsSameOrConvertOf(*e, atom_)) { + atomArg = e; + } + args.push_back(e); + } + } + if (!atomArg) { + return Id::operator()(std::move(x), f); + } + + evaluate::ActualArguments nonAtoms; + + auto AsActual = [](const SomeExpr &z) { + SomeExpr copy = z; + return evaluate::ActualArgument(std::move(copy)); + }; + // Semantic checks guarantee that the "atom" shows exactly once in the + // argument list (with potential conversions around it). + // For the first two (non-optional) arguments, if "atom" is among them, + // replace it with another occurrence of the other non-optional argument. + if (atomArg == args[0]) { + // (atom, x, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[1])); + nonAtoms.push_back(AsActual(*args[1])); + } else if (atomArg == args[1]) { + // (x, atom, y...) -> (x, x, y...) + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[0])); + } else { + // (x, y, z...) -> unchanged + nonAtoms.push_back(AsActual(*args[0])); + nonAtoms.push_back(AsActual(*args[1])); + } + + // The rest of arguments are optional, so we can just skip "atom". + for (size_t i = 2, e = args.size(); i != e; ++i) { + if (atomArg != args[i]) + nonAtoms.push_back(AsActual(*args[i])); + } + + SomeExpr tmp = evaluate::AsGenericExpr( + evaluate::FunctionRef(AsRvalue(proc), AsRvalue(nonAtoms))); + + return evaluate::Expr(evaluate::FunctionRef( + AsRvalue(proc), {AsActual(*atomArg), AsActual(tmp)})); + } + +private: + const SomeExpr &atom_; +}; + +static MaybeExpr PostSemaRewrite(const SomeExpr &atom, const SomeExpr &expr) { + MinMaxRewriter rewriter(atom); + return evaluate::rewrite::Mutator(rewriter)(expr); +} + } // namespace Fortran::semantics