[flang][NFC] Move new code to right place (#144551)

Some new code was added to flang/Semantics that only depends on
facilities in flang/Evaluate. Move it into Evaluate and clean up some
minor stylistic problems.
This commit is contained in:
Peter Klausler 2025-06-19 13:42:46 -07:00 committed by GitHub
parent 03692aa404
commit 9fd22cb56d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 467 additions and 482 deletions

View File

@ -1389,6 +1389,154 @@ inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
}
// Checks whether the symbol on the LHS is present in the RHS expression.
bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);
namespace operation {
enum class Operator {
Unknown,
Add,
And,
Associated,
Call,
Constant,
Convert,
Div,
Eq,
Eqv,
False,
Ge,
Gt,
Identity,
Intrinsic,
Le,
Lt,
Max,
Min,
Mul,
Ne,
Neqv,
Not,
Or,
Pow,
Resize, // Convert within the same TypeCategory
Sub,
True,
};
std::string ToString(Operator op);
template <typename... Ts, int Kind>
Operator OperationCode(
const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
switch (op.derived().logicalOperator) {
case common::LogicalOperator::And:
return Operator::And;
case common::LogicalOperator::Or:
return Operator::Or;
case common::LogicalOperator::Eqv:
return Operator::Eqv;
case common::LogicalOperator::Neqv:
return Operator::Neqv;
case common::LogicalOperator::Not:
return Operator::Not;
}
return Operator::Unknown;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
switch (op.derived().opr) {
case common::RelationalOperator::LT:
return Operator::Lt;
case common::RelationalOperator::LE:
return Operator::Le;
case common::RelationalOperator::EQ:
return Operator::Eq;
case common::RelationalOperator::NE:
return Operator::Ne;
case common::RelationalOperator::GE:
return Operator::Ge;
case common::RelationalOperator::GT:
return Operator::Gt;
}
return Operator::Unknown;
}
template <typename T, typename... Ts>
Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
return Operator::Add;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
return Operator::Sub;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
return Operator::Mul;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
return Operator::Div;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
return Operator::Pow;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
return Operator::Pow;
}
template <typename T, common::TypeCategory C, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
if constexpr (C == T::category) {
return Operator::Resize;
} else {
return Operator::Convert;
}
}
template <typename T> Operator OperationCode(const evaluate::Constant<T> &x) {
return Operator::Constant;
}
template <typename T> Operator OperationCode(const T &) {
return Operator::Unknown;
}
Operator OperationCode(const evaluate::ProcedureDesignator &proc);
} // namespace operation
// Return information about the top-level operation (ignoring parentheses):
// the operation code and the list of arguments.
std::pair<operation::Operator, std::vector<Expr<SomeType>>>
GetTopLevelOperation(const Expr<SomeType> &expr);
// Check if expr is same as x, or a sequence of Convert operations on x.
bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x);
// Strip away any top-level Convert operations (if any exist) and return
// the input value. A ComplexConstructor(x, 0) is also considered as a
// convert operation.
// If the input is not Operation, Designator, FunctionRef or Constant,
// it returns std::nullopt.
std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x);
} // namespace Fortran::evaluate
namespace Fortran::semantics {

View File

@ -756,154 +756,5 @@ std::string GetCommonBlockObjectName(const Symbol &, bool underscoring);
// Check for ambiguous USE associations
bool HadUseError(SemanticsContext &, SourceName at, const Symbol *);
// Checks whether the symbol on the LHS is present in the RHS expression.
bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs);
namespace operation {
enum class Operator {
Unknown,
Add,
And,
Associated,
Call,
Constant,
Convert,
Div,
Eq,
Eqv,
False,
Ge,
Gt,
Identity,
Intrinsic,
Le,
Lt,
Max,
Min,
Mul,
Ne,
Neqv,
Not,
Or,
Pow,
Resize, // Convert within the same TypeCategory
Sub,
True,
};
std::string ToString(Operator op);
template <typename... Ts, int Kind>
Operator OperationCode(
const evaluate::Operation<evaluate::LogicalOperation<Kind>, Ts...> &op) {
switch (op.derived().logicalOperator) {
case common::LogicalOperator::And:
return Operator::And;
case common::LogicalOperator::Or:
return Operator::Or;
case common::LogicalOperator::Eqv:
return Operator::Eqv;
case common::LogicalOperator::Neqv:
return Operator::Neqv;
case common::LogicalOperator::Not:
return Operator::Not;
}
return Operator::Unknown;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Relational<T>, Ts...> &op) {
switch (op.derived().opr) {
case common::RelationalOperator::LT:
return Operator::Lt;
case common::RelationalOperator::LE:
return Operator::Le;
case common::RelationalOperator::EQ:
return Operator::Eq;
case common::RelationalOperator::NE:
return Operator::Ne;
case common::RelationalOperator::GE:
return Operator::Ge;
case common::RelationalOperator::GT:
return Operator::Gt;
}
return Operator::Unknown;
}
template <typename T, typename... Ts>
Operator OperationCode(const evaluate::Operation<evaluate::Add<T>, Ts...> &op) {
return Operator::Add;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Subtract<T>, Ts...> &op) {
return Operator::Sub;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Multiply<T>, Ts...> &op) {
return Operator::Mul;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Divide<T>, Ts...> &op) {
return Operator::Div;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Power<T>, Ts...> &op) {
return Operator::Pow;
}
template <typename T, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::RealToIntPower<T>, Ts...> &op) {
return Operator::Pow;
}
template <typename T, common::TypeCategory C, typename... Ts>
Operator OperationCode(
const evaluate::Operation<evaluate::Convert<T, C>, Ts...> &op) {
if constexpr (C == T::category) {
return Operator::Resize;
} else {
return Operator::Convert;
}
}
template <typename T> //
Operator OperationCode(const evaluate::Constant<T> &x) {
return Operator::Constant;
}
template <typename T> //
Operator OperationCode(const T &) {
return Operator::Unknown;
}
Operator OperationCode(const evaluate::ProcedureDesignator &proc);
} // namespace operation
/// Return information about the top-level operation (ignoring parentheses):
/// the operation code and the list of arguments.
std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
const SomeExpr &expr);
/// Check if expr is same as x, or a sequence of Convert operations on x.
bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x);
/// Strip away any top-level Convert operations (if any exist) and return
/// the input value. A ComplexConstructor(x, 0) is also considered as a
/// convert operation.
/// If the input is not Operation, Designator, FunctionRef or Constant,
/// it returns std::nullopt.
MaybeExpr GetConvertInput(const SomeExpr &x);
} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_TOOLS_H_

View File

@ -13,6 +13,7 @@
#include "flang/Evaluate/traverse.h"
#include "flang/Parser/message.h"
#include "flang/Semantics/tools.h"
#include "llvm/ADT/StringSwitch.h"
#include <algorithm>
#include <variant>
@ -1595,6 +1596,316 @@ bool CheckForCoindexedObject(parser::ContextualMessages &messages,
}
}
bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs) {
if (lhs && rhs) {
if (SymbolVector lhsSymbols{GetSymbolVector(*lhs)}; !lhsSymbols.empty()) {
const Symbol &first{*lhsSymbols.front()};
for (const Symbol &symbol : GetSymbolVector(*rhs)) {
if (first == symbol) {
return true;
}
}
}
}
return false;
}
namespace operation {
template <typename T> Expr<SomeType> AsSomeExpr(const T &x) {
return AsGenericExpr(common::Clone(x));
}
template <bool IgnoreResizingConverts>
struct ArgumentExtractor
: public Traverse<ArgumentExtractor<IgnoreResizingConverts>,
std::pair<operation::Operator, std::vector<Expr<SomeType>>>, false> {
using Arguments = std::vector<Expr<SomeType>>;
using Result = std::pair<operation::Operator, Arguments>;
using Base =
Traverse<ArgumentExtractor<IgnoreResizingConverts>, Result, false>;
static constexpr auto IgnoreResizes{IgnoreResizingConverts};
static constexpr auto Logical{common::TypeCategory::Logical};
ArgumentExtractor() : Base(*this) {}
Result Default() const { return {}; }
using Base::operator();
template <int Kind>
Result operator()(const Constant<Type<Logical, Kind>> &x) const {
if (const auto &val{x.GetScalarValue()}) {
return val->IsTrue()
? std::make_pair(operation::Operator::True, Arguments{})
: std::make_pair(operation::Operator::False, Arguments{});
}
return Default();
}
template <typename R> Result operator()(const FunctionRef<R> &x) const {
Result result{operation::OperationCode(x.proc()), {}};
for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
if (auto *e{x.UnwrapArgExpr(i)}) {
result.second.push_back(*e);
}
}
return result;
}
template <typename D, typename R, typename... Os>
Result operator()(const Operation<D, R, Os...> &x) const {
if constexpr (std::is_same_v<D, Parentheses<R>>) {
// Ignore top-level parentheses.
return (*this)(x.template operand<0>());
}
if constexpr (IgnoreResizes && std::is_same_v<D, Convert<R, R::category>>) {
// Ignore conversions within the same category.
// Atomic operations on int(kind=1) may be implicitly widened
// to int(kind=4) for example.
return (*this)(x.template operand<0>());
} else {
return std::make_pair(operation::OperationCode(x),
OperationArgs(x, std::index_sequence_for<Os...>{}));
}
}
template <typename T> Result operator()(const Designator<T> &x) const {
return {operation::Operator::Identity, {AsSomeExpr(x)}};
}
template <typename T> Result operator()(const Constant<T> &x) const {
return {operation::Operator::Identity, {AsSomeExpr(x)}};
}
template <typename... Rs>
Result Combine(Result &&result, Rs &&...results) const {
// There shouldn't be any combining needed, since we're stopping the
// traversal at the top-level operation, but implement one that picks
// the first non-empty result.
if constexpr (sizeof...(Rs) == 0) {
return std::move(result);
} else {
if (!result.second.empty()) {
return std::move(result);
} else {
return Combine(std::move(results)...);
}
}
}
private:
template <typename D, typename R, typename... Os, size_t... Is>
Arguments OperationArgs(
const Operation<D, R, Os...> &x, std::index_sequence<Is...>) const {
return Arguments{Expr<SomeType>(x.template operand<Is>())...};
}
};
} // namespace operation
std::string operation::ToString(operation::Operator op) {
switch (op) {
case Operator::Unknown:
return "??";
case Operator::Add:
return "+";
case Operator::And:
return "AND";
case Operator::Associated:
return "ASSOCIATED";
case Operator::Call:
return "function-call";
case Operator::Constant:
return "constant";
case Operator::Convert:
return "type-conversion";
case Operator::Div:
return "/";
case Operator::Eq:
return "==";
case Operator::Eqv:
return "EQV";
case Operator::False:
return ".FALSE.";
case Operator::Ge:
return ">=";
case Operator::Gt:
return ">";
case Operator::Identity:
return "identity";
case Operator::Intrinsic:
return "intrinsic";
case Operator::Le:
return "<=";
case Operator::Lt:
return "<";
case Operator::Max:
return "MAX";
case Operator::Min:
return "MIN";
case Operator::Mul:
return "*";
case Operator::Ne:
return "/=";
case Operator::Neqv:
return "NEQV/EOR";
case Operator::Not:
return "NOT";
case Operator::Or:
return "OR";
case Operator::Pow:
return "**";
case Operator::Resize:
return "resize";
case Operator::Sub:
return "-";
case Operator::True:
return ".TRUE.";
}
llvm_unreachable("Unhandler operator");
}
operation::Operator operation::OperationCode(const ProcedureDesignator &proc) {
Operator code{llvm::StringSwitch<Operator>(proc.GetName())
.Case("associated", Operator::Associated)
.Case("min", Operator::Min)
.Case("max", Operator::Max)
.Case("iand", Operator::And)
.Case("ior", Operator::Or)
.Case("ieor", Operator::Neqv)
.Default(Operator::Call)};
if (code == Operator::Call && proc.GetSpecificIntrinsic()) {
return Operator::Intrinsic;
}
return code;
}
std::pair<operation::Operator, std::vector<Expr<SomeType>>>
GetTopLevelOperation(const Expr<SomeType> &expr) {
return operation::ArgumentExtractor<true>{}(expr);
}
namespace operation {
struct ConvertCollector
: public Traverse<ConvertCollector,
std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>,
false> {
using Result =
std::pair<std::optional<Expr<SomeType>>, std::vector<DynamicType>>;
using Base = Traverse<ConvertCollector, Result, false>;
ConvertCollector() : Base(*this) {}
Result Default() const { return {}; }
using Base::operator();
template <typename T> Result operator()(const Designator<T> &x) const {
return {AsSomeExpr(x), {}};
}
template <typename T> Result operator()(const FunctionRef<T> &x) const {
return {AsSomeExpr(x), {}};
}
template <typename T> Result operator()(const Constant<T> &x) const {
return {AsSomeExpr(x), {}};
}
template <typename D, typename R, typename... Os>
Result operator()(const Operation<D, R, Os...> &x) const {
if constexpr (std::is_same_v<D, Parentheses<R>>) {
// Ignore parentheses.
return (*this)(x.template operand<0>());
} else if constexpr (is_convert_v<D>) {
// Convert should always have a typed result, so it should be safe to
// dereference x.GetType().
return Combine(
{std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
} else if constexpr (is_complex_constructor_v<D>) {
// This is a conversion iff the imaginary operand is 0.
if (IsZero(x.template operand<1>())) {
return Combine(
{std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
} else {
return {AsSomeExpr(x.derived()), {}};
}
} else {
return {AsSomeExpr(x.derived()), {}};
}
}
template <typename... Rs>
Result Combine(Result &&result, Rs &&...results) const {
Result v(std::move(result));
auto setValue{[](std::optional<Expr<SomeType>> &x,
std::optional<Expr<SomeType>> &&y) {
assert((!x.has_value() || !y.has_value()) && "Multiple designators");
if (!x.has_value()) {
x = std::move(y);
}
}};
auto moveAppend{[](auto &accum, auto &&other) {
for (auto &&s : other) {
accum.push_back(std::move(s));
}
}};
(setValue(v.first, std::move(results).first), ...);
(moveAppend(v.second, std::move(results).second), ...);
return v;
}
private:
template <typename A> static bool IsZero(const A &x) { return false; }
template <typename T> static bool IsZero(const Expr<T> &x) {
return common::visit([](auto &&s) { return IsZero(s); }, x.u);
}
template <typename T> static bool IsZero(const Constant<T> &x) {
if (auto &&maybeScalar{x.GetScalarValue()}) {
return maybeScalar->IsZero();
} else {
return false;
}
}
template <typename T> struct is_convert {
static constexpr bool value{false};
};
template <typename T, common::TypeCategory C>
struct is_convert<Convert<T, C>> {
static constexpr bool value{true};
};
template <int K> struct is_convert<ComplexComponent<K>> {
// Conversion from complex to real.
static constexpr bool value{true};
};
template <typename T>
static constexpr bool is_convert_v{is_convert<T>::value};
template <typename T> struct is_complex_constructor {
static constexpr bool value{false};
};
template <int K> struct is_complex_constructor<ComplexConstructor<K>> {
static constexpr bool value{true};
};
template <typename T>
static constexpr bool is_complex_constructor_v{
is_complex_constructor<T>::value};
};
} // namespace operation
std::optional<Expr<SomeType>> GetConvertInput(const Expr<SomeType> &x) {
// This returns Expr<SomeType>{x} when x is a designator/functionref/constant.
return operation::ConvertCollector{}(x).first;
}
bool IsSameOrConvertOf(const Expr<SomeType> &expr, const Expr<SomeType> &x) {
// Check if expr is same as x, or a sequence of Convert operations on x.
if (expr == x) {
return true;
} else if (auto maybe{GetConvertInput(expr)}) {
return *maybe == x;
} else {
return false;
}
}
} // namespace Fortran::evaluate
namespace Fortran::semantics {

View File

@ -654,7 +654,7 @@ void genAtomicCapture(Fortran::lower::AbstractConverter &converter,
mlir::Block &block = atomicCaptureOp->getRegion(0).back();
firOpBuilder.setInsertionPointToStart(&block);
if (Fortran::parser::CheckForSingleVariableOnRHS(stmt1)) {
if (Fortran::semantics::CheckForSymbolMatch(
if (Fortran::evaluate::CheckForSymbolMatch(
Fortran::semantics::GetExpr(stmt2Var),
Fortran::semantics::GetExpr(stmt2Expr))) {
// Atomic capture construct is of the form [capture-stmt, update-stmt]

View File

@ -2934,11 +2934,12 @@ genAtomicUpdate(lower::AbstractConverter &converter,
mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
// This must exist by now.
SomeExpr input = *semantics::GetConvertInput(assign.rhs);
std::vector<SomeExpr> args{semantics::GetTopLevelOperation(input).second};
SomeExpr input = *Fortran::evaluate::GetConvertInput(assign.rhs);
std::vector<SomeExpr> args{
Fortran::evaluate::GetTopLevelOperation(input).second};
assert(!args.empty() && "Update operation without arguments");
for (auto &arg : args) {
if (!semantics::IsSameOrConvertOf(arg, atom)) {
if (!Fortran::evaluate::IsSameOrConvertOf(arg, atom)) {
mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
overrides.try_emplace(&arg, val);
}

View File

@ -12,6 +12,7 @@
#include "flang/Evaluate/check-expression.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/shape.h"
#include "flang/Evaluate/tools.h"
#include "flang/Evaluate/type.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/expression.h"
@ -2987,6 +2988,8 @@ static bool IsPointerAssignment(const evaluate::Assignment &x) {
std::holds_alternative<evaluate::Assignment::BoundsRemapping>(x.u);
}
namespace operation = Fortran::evaluate::operation;
static bool IsCheckForAssociated(const SomeExpr &cond) {
return GetTopLevelOperation(cond).first == operation::Operator::Associated;
}

View File

@ -17,7 +17,6 @@
#include "flang/Semantics/tools.h"
#include "flang/Semantics/type.h"
#include "flang/Support/Fortran.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <set>
@ -1789,332 +1788,4 @@ bool HadUseError(
}
}
bool CheckForSymbolMatch(const SomeExpr *lhs, const SomeExpr *rhs) {
if (lhs && rhs) {
if (SymbolVector lhsSymbols{evaluate::GetSymbolVector(*lhs)};
!lhsSymbols.empty()) {
const Symbol &first{*lhsSymbols.front()};
for (const Symbol &symbol : evaluate::GetSymbolVector(*rhs)) {
if (first == symbol) {
return true;
}
}
}
}
return false;
}
namespace operation {
template <typename T> //
SomeExpr asSomeExpr(const T &x) {
auto copy{x};
return AsGenericExpr(std::move(copy));
}
template <bool IgnoreResizingConverts> //
struct ArgumentExtractor
: public evaluate::Traverse<ArgumentExtractor<IgnoreResizingConverts>,
std::pair<operation::Operator, std::vector<SomeExpr>>, false> {
using Arguments = std::vector<SomeExpr>;
using Result = std::pair<operation::Operator, Arguments>;
using Base = evaluate::Traverse<ArgumentExtractor<IgnoreResizingConverts>,
Result, false>;
static constexpr auto IgnoreResizes = IgnoreResizingConverts;
static constexpr auto Logical = common::TypeCategory::Logical;
ArgumentExtractor() : Base(*this) {}
Result Default() const { return {}; }
using Base::operator();
template <int Kind> //
Result operator()(
const evaluate::Constant<evaluate::Type<Logical, Kind>> &x) const {
if (const auto &val{x.GetScalarValue()}) {
return val->IsTrue()
? std::make_pair(operation::Operator::True, Arguments{})
: std::make_pair(operation::Operator::False, Arguments{});
}
return Default();
}
template <typename R> //
Result operator()(const evaluate::FunctionRef<R> &x) const {
Result result{operation::OperationCode(x.proc()), {}};
for (size_t i{0}, e{x.arguments().size()}; i != e; ++i) {
if (auto *e{x.UnwrapArgExpr(i)}) {
result.second.push_back(*e);
}
}
return result;
}
template <typename D, typename R, typename... Os>
Result operator()(const evaluate::Operation<D, R, Os...> &x) const {
if constexpr (std::is_same_v<D, evaluate::Parentheses<R>>) {
// Ignore top-level parentheses.
return (*this)(x.template operand<0>());
}
if constexpr (IgnoreResizes &&
std::is_same_v<D, evaluate::Convert<R, R::category>>) {
// Ignore conversions within the same category.
// Atomic operations on int(kind=1) may be implicitly widened
// to int(kind=4) for example.
return (*this)(x.template operand<0>());
} else {
return std::make_pair(operation::OperationCode(x),
OperationArgs(x, std::index_sequence_for<Os...>{}));
}
}
template <typename T> //
Result operator()(const evaluate::Designator<T> &x) const {
return {operation::Operator::Identity, {asSomeExpr(x)}};
}
template <typename T> //
Result operator()(const evaluate::Constant<T> &x) const {
return {operation::Operator::Identity, {asSomeExpr(x)}};
}
template <typename... Rs> //
Result Combine(Result &&result, Rs &&...results) const {
// There shouldn't be any combining needed, since we're stopping the
// traversal at the top-level operation, but implement one that picks
// the first non-empty result.
if constexpr (sizeof...(Rs) == 0) {
return std::move(result);
} else {
if (!result.second.empty()) {
return std::move(result);
} else {
return Combine(std::move(results)...);
}
}
}
private:
template <typename D, typename R, typename... Os, size_t... Is>
Arguments OperationArgs(const evaluate::Operation<D, R, Os...> &x,
std::index_sequence<Is...>) const {
return Arguments{SomeExpr(x.template operand<Is>())...};
}
};
} // namespace operation
std::string operation::ToString(operation::Operator op) {
switch (op) {
case Operator::Unknown:
return "??";
case Operator::Add:
return "+";
case Operator::And:
return "AND";
case Operator::Associated:
return "ASSOCIATED";
case Operator::Call:
return "function-call";
case Operator::Constant:
return "constant";
case Operator::Convert:
return "type-conversion";
case Operator::Div:
return "/";
case Operator::Eq:
return "==";
case Operator::Eqv:
return "EQV";
case Operator::False:
return ".FALSE.";
case Operator::Ge:
return ">=";
case Operator::Gt:
return ">";
case Operator::Identity:
return "identity";
case Operator::Intrinsic:
return "intrinsic";
case Operator::Le:
return "<=";
case Operator::Lt:
return "<";
case Operator::Max:
return "MAX";
case Operator::Min:
return "MIN";
case Operator::Mul:
return "*";
case Operator::Ne:
return "/=";
case Operator::Neqv:
return "NEQV/EOR";
case Operator::Not:
return "NOT";
case Operator::Or:
return "OR";
case Operator::Pow:
return "**";
case Operator::Resize:
return "resize";
case Operator::Sub:
return "-";
case Operator::True:
return ".TRUE.";
}
llvm_unreachable("Unhandler operator");
}
operation::Operator operation::OperationCode(
const evaluate::ProcedureDesignator &proc) {
Operator code = llvm::StringSwitch<Operator>(proc.GetName())
.Case("associated", Operator::Associated)
.Case("min", Operator::Min)
.Case("max", Operator::Max)
.Case("iand", Operator::And)
.Case("ior", Operator::Or)
.Case("ieor", Operator::Neqv)
.Default(Operator::Call);
if (code == Operator::Call && proc.GetSpecificIntrinsic()) {
return Operator::Intrinsic;
}
return code;
}
std::pair<operation::Operator, std::vector<SomeExpr>> GetTopLevelOperation(
const SomeExpr &expr) {
return operation::ArgumentExtractor<true>{}(expr);
}
namespace operation {
struct ConvertCollector
: public evaluate::Traverse<ConvertCollector,
std::pair<MaybeExpr, std::vector<evaluate::DynamicType>>, false> {
using Result = std::pair<MaybeExpr, std::vector<evaluate::DynamicType>>;
using Base = evaluate::Traverse<ConvertCollector, Result, false>;
ConvertCollector() : Base(*this) {}
Result Default() const { return {}; }
using Base::operator();
template <typename T> //
Result operator()(const evaluate::Designator<T> &x) const {
return {asSomeExpr(x), {}};
}
template <typename T> //
Result operator()(const evaluate::FunctionRef<T> &x) const {
return {asSomeExpr(x), {}};
}
template <typename T> //
Result operator()(const evaluate::Constant<T> &x) const {
return {asSomeExpr(x), {}};
}
template <typename D, typename R, typename... Os>
Result operator()(const evaluate::Operation<D, R, Os...> &x) const {
if constexpr (std::is_same_v<D, evaluate::Parentheses<R>>) {
// Ignore parentheses.
return (*this)(x.template operand<0>());
} else if constexpr (is_convert_v<D>) {
// Convert should always have a typed result, so it should be safe to
// dereference x.GetType().
return Combine(
{std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
} else if constexpr (is_complex_constructor_v<D>) {
// This is a conversion iff the imaginary operand is 0.
if (IsZero(x.template operand<1>())) {
return Combine(
{std::nullopt, {*x.GetType()}}, (*this)(x.template operand<0>()));
} else {
return {asSomeExpr(x.derived()), {}};
}
} else {
return {asSomeExpr(x.derived()), {}};
}
}
template <typename... Rs> //
Result Combine(Result &&result, Rs &&...results) const {
Result v(std::move(result));
auto setValue{[](MaybeExpr &x, MaybeExpr &&y) {
assert((!x.has_value() || !y.has_value()) && "Multiple designators");
if (!x.has_value()) {
x = std::move(y);
}
}};
auto moveAppend{[](auto &accum, auto &&other) {
for (auto &&s : other) {
accum.push_back(std::move(s));
}
}};
(setValue(v.first, std::move(results).first), ...);
(moveAppend(v.second, std::move(results).second), ...);
return v;
}
private:
template <typename T> //
static bool IsZero(const T &x) {
return false;
}
template <typename T> //
static bool IsZero(const evaluate::Expr<T> &x) {
return common::visit([](auto &&s) { return IsZero(s); }, x.u);
}
template <typename T> //
static bool IsZero(const evaluate::Constant<T> &x) {
if (auto &&maybeScalar{x.GetScalarValue()}) {
return maybeScalar->IsZero();
} else {
return false;
}
}
template <typename T> //
struct is_convert {
static constexpr bool value{false};
};
template <typename T, common::TypeCategory C> //
struct is_convert<evaluate::Convert<T, C>> {
static constexpr bool value{true};
};
template <int K> //
struct is_convert<evaluate::ComplexComponent<K>> {
// Conversion from complex to real.
static constexpr bool value{true};
};
template <typename T> //
static constexpr bool is_convert_v = is_convert<T>::value;
template <typename T> //
struct is_complex_constructor {
static constexpr bool value{false};
};
template <int K> //
struct is_complex_constructor<evaluate::ComplexConstructor<K>> {
static constexpr bool value{true};
};
template <typename T> //
static constexpr bool is_complex_constructor_v =
is_complex_constructor<T>::value;
};
} // namespace operation
MaybeExpr GetConvertInput(const SomeExpr &x) {
// This returns SomeExpr(x) when x is a designator/functionref/constant.
return operation::ConvertCollector{}(x).first;
}
bool IsSameOrConvertOf(const SomeExpr &expr, const SomeExpr &x) {
// Check if expr is same as x, or a sequence of Convert operations on x.
if (expr == x) {
return true;
} else if (auto maybe{GetConvertInput(expr)}) {
return *maybe == x;
} else {
return false;
}
}
} // namespace Fortran::semantics