[flang] Fold EOSHIFT
Implement constant folding for the transformational intrinsic function EOSHIFT. Differential Revision: https://reviews.llvm.org/D108941
This commit is contained in:
parent
faf1c22408
commit
3fefebabe5
@ -62,6 +62,7 @@ public:
|
||||
Constant<T> *Folding(std::optional<ActualArgument> &);
|
||||
|
||||
Expr<T> CSHIFT(FunctionRef<T> &&);
|
||||
Expr<T> EOSHIFT(FunctionRef<T> &&);
|
||||
Expr<T> RESHAPE(FunctionRef<T> &&);
|
||||
|
||||
private:
|
||||
@ -619,6 +620,112 @@ template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
|
||||
return MakeInvalidIntrinsic(std::move(funcRef));
|
||||
}
|
||||
|
||||
template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
|
||||
auto args{funcRef.arguments()};
|
||||
CHECK(args.size() == 4);
|
||||
const auto *array{UnwrapConstantValue<T>(args[0])};
|
||||
const auto *shiftExpr{UnwrapExpr<Expr<SomeInteger>>(args[1])};
|
||||
auto dim{GetInt64ArgOr(args[3], 1)};
|
||||
if (!array || !shiftExpr || !dim) {
|
||||
return Expr<T>{std::move(funcRef)};
|
||||
}
|
||||
// Apply type conversions to the shift= and boundary= arguments.
|
||||
auto convertedShift{Fold(context_,
|
||||
ConvertToType<SubscriptInteger>(Expr<SomeInteger>{*shiftExpr}))};
|
||||
const auto *shift{UnwrapConstantValue<SubscriptInteger>(convertedShift)};
|
||||
if (!shift) {
|
||||
return Expr<T>{std::move(funcRef)};
|
||||
}
|
||||
const Constant<T> *boundary{nullptr};
|
||||
std::optional<Expr<SomeType>> convertedBoundary;
|
||||
if (const auto *boundaryExpr{UnwrapExpr<Expr<SomeType>>(args[2])}) {
|
||||
convertedBoundary = Fold(context_,
|
||||
ConvertToType(array->GetType(), Expr<SomeType>{*boundaryExpr}));
|
||||
boundary = UnwrapExpr<Constant<T>>(convertedBoundary);
|
||||
if (!boundary) {
|
||||
return Expr<T>{std::move(funcRef)};
|
||||
}
|
||||
}
|
||||
// Arguments are constant
|
||||
if (*dim < 1 || *dim > array->Rank()) {
|
||||
context_.messages().Say(
|
||||
"Invalid 'dim=' argument (%jd) in EOSHIFT"_err_en_US,
|
||||
static_cast<std::intmax_t>(*dim));
|
||||
} else if (shift->Rank() > 0 && shift->Rank() != array->Rank() - 1) {
|
||||
// message already emitted from intrinsic look-up
|
||||
} else {
|
||||
int rank{array->Rank()};
|
||||
int zbDim{static_cast<int>(*dim) - 1};
|
||||
bool ok{true};
|
||||
if (shift->Rank() > 0) {
|
||||
int k{0};
|
||||
for (int j{0}; j < rank; ++j) {
|
||||
if (j != zbDim) {
|
||||
if (array->shape()[j] != shift->shape()[k]) {
|
||||
context_.messages().Say(
|
||||
"Invalid 'shift=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
|
||||
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
|
||||
static_cast<std::intmax_t>(array->shape()[j]));
|
||||
ok = false;
|
||||
}
|
||||
if (boundary && array->shape()[j] != boundary->shape()[k]) {
|
||||
context_.messages().Say(
|
||||
"Invalid 'boundary=' argument in EOSHIFT; extent on dimension %d is %jd but must be %jd"_err_en_US,
|
||||
k + 1, static_cast<std::intmax_t>(shift->shape()[k]),
|
||||
static_cast<std::intmax_t>(array->shape()[j]));
|
||||
ok = false;
|
||||
}
|
||||
++k;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ok) {
|
||||
std::vector<Scalar<T>> resultElements;
|
||||
ConstantSubscripts arrayAt{array->lbounds()};
|
||||
ConstantSubscript dimLB{arrayAt[zbDim]};
|
||||
ConstantSubscript dimExtent{array->shape()[zbDim]};
|
||||
ConstantSubscripts shiftAt{shift->lbounds()};
|
||||
ConstantSubscripts boundaryAt;
|
||||
if (boundary) {
|
||||
boundaryAt = boundary->lbounds();
|
||||
}
|
||||
for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
|
||||
ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
|
||||
for (ConstantSubscript j{0}; j < dimExtent; ++j) {
|
||||
ConstantSubscript zbAt{shiftCount + j};
|
||||
if (zbAt >= 0 && zbAt < dimExtent) {
|
||||
arrayAt[zbDim] = dimLB + zbAt;
|
||||
resultElements.push_back(array->At(arrayAt));
|
||||
} else if (boundary) {
|
||||
resultElements.push_back(boundary->At(boundaryAt));
|
||||
} else if constexpr (T::category == TypeCategory::Integer ||
|
||||
T::category == TypeCategory::Real ||
|
||||
T::category == TypeCategory::Complex ||
|
||||
T::category == TypeCategory::Logical) {
|
||||
resultElements.emplace_back();
|
||||
} else if constexpr (T::category == TypeCategory::Character) {
|
||||
auto len{static_cast<std::size_t>(array->LEN())};
|
||||
typename Scalar<T>::value_type space{' '};
|
||||
resultElements.emplace_back(len, space);
|
||||
} else {
|
||||
DIE("no derived type boundary");
|
||||
}
|
||||
}
|
||||
arrayAt[zbDim] = dimLB + dimExtent - 1;
|
||||
array->IncrementSubscripts(arrayAt);
|
||||
shift->IncrementSubscripts(shiftAt);
|
||||
if (boundary) {
|
||||
boundary->IncrementSubscripts(boundaryAt);
|
||||
}
|
||||
}
|
||||
return Expr<T>{PackageConstant<T>(
|
||||
std::move(resultElements), *array, array->shape())};
|
||||
}
|
||||
}
|
||||
// Invalid, prevent re-folding
|
||||
return MakeInvalidIntrinsic(std::move(funcRef));
|
||||
}
|
||||
|
||||
template <typename T> Expr<T> Folder<T>::RESHAPE(FunctionRef<T> &&funcRef) {
|
||||
auto args{funcRef.arguments()};
|
||||
CHECK(args.size() == 4);
|
||||
@ -754,6 +861,8 @@ Expr<T> FoldOperation(FoldingContext &context, FunctionRef<T> &&funcRef) {
|
||||
const std::string name{intrinsic->name};
|
||||
if (name == "cshift") {
|
||||
return Folder<T>{context}.CSHIFT(std::move(funcRef));
|
||||
} else if (name == "eoshift") {
|
||||
return Folder<T>{context}.EOSHIFT(std::move(funcRef));
|
||||
} else if (name == "reshape") {
|
||||
return Folder<T>{context}.RESHAPE(std::move(funcRef));
|
||||
}
|
||||
|
||||
@ -125,7 +125,7 @@ Expr<Type<TypeCategory::Logical, KIND>> FoldIntrinsicFunction(
|
||||
name == "__builtin_ieee_support_underflow_control") {
|
||||
return Expr<T>{true};
|
||||
}
|
||||
// TODO: btest, dot_product, eoshift, is_iostat_end,
|
||||
// TODO: btest, dot_product, is_iostat_end,
|
||||
// is_iostat_eor, lge, lgt, lle, llt, logical, matmul, out_of_range,
|
||||
// parity, transfer
|
||||
return Expr<T>{std::move(funcRef)};
|
||||
|
||||
@ -385,15 +385,17 @@ static const IntrinsicInterface genericIntrinsicFunction[]{
|
||||
{"eoshift",
|
||||
{{"array", SameIntrinsic, Rank::array},
|
||||
{"shift", AnyInt, Rank::dimRemovedOrScalar},
|
||||
{"boundary", SameIntrinsic, Rank::dimReduced,
|
||||
{"boundary", SameIntrinsic, Rank::dimRemovedOrScalar,
|
||||
Optionality::optional},
|
||||
OptionalDIM},
|
||||
SameIntrinsic, Rank::conformable,
|
||||
IntrinsicClass::transformationalFunction},
|
||||
{"eoshift",
|
||||
{{"array", SameDerivedType, Rank::array},
|
||||
{"shift", AnyInt, Rank::dimReduced},
|
||||
{"boundary", SameDerivedType, Rank::dimReduced}, OptionalDIM},
|
||||
{"shift", AnyInt, Rank::dimRemovedOrScalar},
|
||||
// BOUNDARY= is not optional for derived types
|
||||
{"boundary", SameDerivedType, Rank::dimRemovedOrScalar},
|
||||
OptionalDIM},
|
||||
SameDerivedType, Rank::conformable,
|
||||
IntrinsicClass::transformationalFunction},
|
||||
{"epsilon", {{"x", SameReal, Rank::anyOrAssumedRank}}, SameReal,
|
||||
|
||||
16
flang/test/Evaluate/folding23.f90
Normal file
16
flang/test/Evaluate/folding23.f90
Normal file
@ -0,0 +1,16 @@
|
||||
! RUN: %S/test_folding.sh %s %t %flang_fc1
|
||||
! REQUIRES: shell
|
||||
! Tests folding of EOSHIFT (valid cases)
|
||||
module m
|
||||
integer, parameter :: arr(2,3) = reshape([1, 2, 3, 4, 5, 6], shape(arr))
|
||||
logical, parameter :: test_sanity = all([arr] == [1, 2, 3, 4, 5, 6])
|
||||
logical, parameter :: test_eoshift_0 = all(eoshift([1, 2, 3], 0) == [1, 2, 3])
|
||||
logical, parameter :: test_eoshift_1 = all(eoshift([1, 2, 3], 1) == [2, 3, 0])
|
||||
logical, parameter :: test_eoshift_2 = all(eoshift([1, 2, 3], -1) == [0, 1, 2])
|
||||
logical, parameter :: test_eoshift_3 = all(eoshift([1., 2., 3.], 1) == [2., 3., 0.])
|
||||
logical, parameter :: test_eoshift_4 = all(eoshift(['ab', 'cd', 'ef'], -1, 'x') == ['x ', 'ab', 'cd'])
|
||||
logical, parameter :: test_eoshift_5 = all([eoshift(arr, 1, dim=1)] == [2, 0, 4, 0, 6, 0])
|
||||
logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 5, 0, 4, 6, 0])
|
||||
logical, parameter :: test_eoshift_7 = all([eoshift(arr, [1, -1, 0])] == [2, 0, 0, 3, 5, 6])
|
||||
logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 5, 0, 0, 2, 4])
|
||||
end module
|
||||
Loading…
x
Reference in New Issue
Block a user