llvm-project/flang/lib/Lower/Support/ReductionProcessor.cpp
Matt e80604a641
[flang][OpenMP] Support user-defined declare reduction with derived types (#184897)
Fix lowering of `!$omp declare reduction` for intrinsic operators
applied
to user-defined derived types (e.g., `+` on `type(t)`). Previously, this
hit a TODO in `ReductionProcessor::getReductionInitValue` because the
code
tried to compute an init value for a non-predefined type, when it should
instead use the initializer region from the `DeclareReductionOp`.

This fixes the issue #176278: [Flang][OpenMP] Compilation error when
type-list in declare reduction directive is derived type name.

The root cause was a naming mismatch: `genOMP` for
`OpenMPDeclareReductionConstruct` used a raw operator string (e.g.,
"Add")
as the reduction name, while `processReductionArguments` at the use site
computed a canonical name via `getReductionName` (e.g.,
"add_reduction_byref_rec__QFTt"). The `lookupSymbol` in
`createDeclareReductionHelper` never found the already-created op, so it
fell through to `createDeclareReduction` which called
`getReductionInitValue`
with the derived type and hit the TODO.

The fix has three parts:

1. Consistent names: In `genOMP` for `OpenMPDeclareReductionConstruct`,
compute
the reduction name using the same `getReductionName` scheme that
`processReductionArguments` uses, so both sites produce identical symbol
names.
For intrinsic operators, this maps through `ReductionIdentifier` to get
the
canonical name. For user-defined named reductions, the raw symbol name
is used
directly, matching the existing custom-reduction lookup path.

2. Reuse reduction: In `processReductionArguments`, when an intrinsic
operator
reduction is requested, check whether a user-defined declare reduction
already
exists under that canonical name before attempting to create a new one.
If
found, reuse it. This avoids calling `createDeclareReduction` (and thus
`getReductionInitValue`) for types that have user-provided initializers.

3. Reference semantics: Change `doReductionByRef` to return true for
derived
types. Previously it returned false for both trivial and derived types,
treating
derived types as by-val. This is incorrect for user-defined combiners
that
operate on components via side-effects (e.g., `omp_out%x = omp_out%x +
omp_in%x`): the combiner mutates `omp_out` in place and doesn't produce
a
whole-struct value, so `convertExprToValue` returns the component type
(`i32`) rather than the struct type, causing a type mismatch in the
`omp.yield`. By-ref is the correct model: the combiner stores into the
lhs reference and yields it.

The combiner callback in `processReductionCombiner` is also updated to
handle the by-ref derived-type case: when the combiner result type
doesn't match the element type (as happens with component-level
assignments), the store is skipped since the assignment already wrote
into omp_out as a side-effect, and only the lhs reference is yielded.

Tests updates:
- Update declare-reduction-intrinsic-op.f90 from a negative test
(checking
for the TODO error) to a positive test checking the generated MLIR.
- Update omp-declare-reduction-derivedtype.f90 CHECK lines to match the
reference semantics fix: the `declare_reduction` now has type
`!fir.ref<...>`
with a `byref_element_type` attribute, an alloc region, a two-argument
init
region, and a combiner that stores into the lhs and yields the
reference. The function body checks for initme and mycombine are
unchanged in substance but use literal type names instead of a regex
capture to avoid greedy matching issues with nested angle brackets.

Remaining work: declare reduction without an initializer clause is not
yet
supported. I plan to address that subsequently.

Assisted-by: Claude Opus 4.6.

Note: Relied on LLM (Claude Opus 4.6) to help navigate the Flang APIs
and assist
with the corresponding boilerplate code & tests updates; in particular:
in order
to get the aforementioned consistent naming, in
`ReductionProcessor::getReductionName` I had to get rid of
`parser::DefinedOperator::EnumToString` and instead introduce
`getRedIdFromParserIntrOp` (which does the conversion manually; just to
make
sure I haven't missed anything: is there no existing conversion
function?
AFAICT, there is none, but I might've missed it). In any case, feedback
welcome!

---------

Co-authored-by: Matt P. Dziubinski <matt-p.dziubinski@hpe.com>
2026-03-26 15:48:47 +00:00

884 lines
35 KiB
C++

//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/Support/PrivateReductionUtils.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>
static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
llvm::cl::desc("Pass all reduction arguments by reference"),
llvm::cl::Hidden);
using ReductionModifier =
Fortran::lower::omp::clause::Reduction::ReductionModifier;
namespace Fortran {
namespace lower {
namespace omp {
// explicit template declarations
template bool ReductionProcessor::processReductionArguments<
mlir::omp::DeclareReductionOp, omp::clause::ReductionOperatorList>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::ReductionOperatorList &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
template bool ReductionProcessor::processReductionArguments<
fir::DeclareReductionOp, llvm::SmallVector<fir::ReduceOperationEnum>>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const llvm::SmallVector<fir::ReduceOperationEnum> &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
template mlir::omp::DeclareReductionOp
ReductionProcessor::createDeclareReduction<mlir::omp::DeclareReductionOp>(
AbstractConverter &converter, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef);
template fir::DeclareReductionOp
ReductionProcessor::createDeclareReduction<fir::DeclareReductionOp>(
AbstractConverter &converter, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef);
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
getRealName(pd.v.sym()).ToString())
.Case("max", ReductionIdentifier::MAX)
.Case("min", ReductionIdentifier::MIN)
.Case("iand", ReductionIdentifier::IAND)
.Case("ior", ReductionIdentifier::IOR)
.Case("ieor", ReductionIdentifier::IEOR)
.Default(std::nullopt);
assert(redType && "Invalid Reduction");
return *redType;
}
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
switch (intrinsicOp) {
case omp::clause::DefinedOperator::IntrinsicOperator::Add:
return ReductionIdentifier::ADD;
case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
return ReductionIdentifier::SUBTRACT;
case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
return ReductionIdentifier::MULTIPLY;
case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return ReductionIdentifier::AND;
case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return ReductionIdentifier::EQV;
case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return ReductionIdentifier::OR;
case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return ReductionIdentifier::NEQV;
default:
llvm_unreachable("unexpected intrinsic operator in reduction");
}
}
ReductionProcessor::ReductionIdentifier
ReductionProcessor::getReductionType(const fir::ReduceOperationEnum &redOp) {
switch (redOp) {
case fir::ReduceOperationEnum::Add:
return ReductionIdentifier::ADD;
case fir::ReduceOperationEnum::Multiply:
return ReductionIdentifier::MULTIPLY;
case fir::ReduceOperationEnum::AND:
return ReductionIdentifier::AND;
case fir::ReduceOperationEnum::OR:
return ReductionIdentifier::OR;
case fir::ReduceOperationEnum::EQV:
return ReductionIdentifier::EQV;
case fir::ReduceOperationEnum::NEQV:
return ReductionIdentifier::NEQV;
case fir::ReduceOperationEnum::IAND:
return ReductionIdentifier::IAND;
case fir::ReduceOperationEnum::IEOR:
return ReductionIdentifier::IEOR;
case fir::ReduceOperationEnum::IOR:
return ReductionIdentifier::IOR;
case fir::ReduceOperationEnum::MAX:
return ReductionIdentifier::MAX;
case fir::ReduceOperationEnum::MIN:
return ReductionIdentifier::MIN;
}
llvm_unreachable("Unhandled ReductionIdentifier case");
}
bool ReductionProcessor::supportedIntrinsicProcReduction(
const omp::clause::ProcedureDesignator &pd) {
semantics::Symbol *sym = pd.v.sym();
if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC))
return false;
auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
.Case("max", true)
.Case("min", true)
.Case("iand", true)
.Case("ior", true)
.Case("ieor", true)
.Default(false);
return redType;
}
std::string
ReductionProcessor::getReductionName(llvm::StringRef name,
const fir::KindMapping &kindMap,
mlir::Type ty, bool isByRef) {
ty = fir::unwrapRefType(ty);
// extra string to distinguish reduction functions for variables passed by
// reference
llvm::StringRef byrefAddition{""};
if (isByRef)
byrefAddition = "_byref";
return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
}
std::string
ReductionProcessor::getReductionName(ReductionIdentifier redId,
const fir::KindMapping &kindMap,
mlir::Type ty, bool isByRef) {
std::string reductionName;
switch (redId) {
case ReductionIdentifier::ADD:
reductionName = "add_reduction";
break;
case ReductionIdentifier::MULTIPLY:
reductionName = "multiply_reduction";
break;
case ReductionIdentifier::AND:
reductionName = "and_reduction";
break;
case ReductionIdentifier::EQV:
reductionName = "eqv_reduction";
break;
case ReductionIdentifier::OR:
reductionName = "or_reduction";
break;
case ReductionIdentifier::NEQV:
reductionName = "neqv_reduction";
break;
default:
reductionName = "other_reduction";
break;
}
return getReductionName(reductionName, kindMap, ty, isByRef);
}
mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
ReductionIdentifier redId,
fir::FirOpBuilder &builder) {
type = fir::unwrapRefType(type);
if (!fir::isa_integer(type) && !fir::isa_real(type) &&
!fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
TODO(loc, "Reduction of some types is not supported");
switch (redId) {
case ReductionIdentifier::MAX: {
if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
}
unsigned bits = type.getIntOrFloatBitWidth();
int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, minInt);
}
case ReductionIdentifier::MIN: {
if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
}
unsigned bits = type.getIntOrFloatBitWidth();
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, maxInt);
}
case ReductionIdentifier::IOR: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, zeroInt);
}
case ReductionIdentifier::IEOR: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, zeroInt);
}
case ReductionIdentifier::IAND: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, allOnInt);
}
case ReductionIdentifier::ADD:
case ReductionIdentifier::MULTIPLY:
case ReductionIdentifier::AND:
case ReductionIdentifier::OR:
case ReductionIdentifier::EQV:
case ReductionIdentifier::NEQV:
if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
mlir::Type realTy = cplxTy.getElementType();
mlir::Value initRe = builder.createRealConstant(
loc, realTy, getOperationIdentity(redId, loc));
mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
initIm);
}
if (mlir::isa<mlir::FloatType>(type))
return mlir::arith::ConstantOp::create(
builder, loc, type,
builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
if (mlir::isa<fir::LogicalType>(type)) {
mlir::Value intConst = mlir::arith::ConstantOp::create(
builder, loc, builder.getI1Type(),
builder.getIntegerAttr(builder.getI1Type(),
getOperationIdentity(redId, loc)));
return builder.createConvert(loc, type, intConst);
}
return mlir::arith::ConstantOp::create(
builder, loc, type,
builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
case ReductionIdentifier::ID:
case ReductionIdentifier::USER_DEF_OP:
case ReductionIdentifier::SUBTRACT:
TODO(loc, "Reduction of some identifier types is not supported");
}
llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
}
mlir::Value ReductionProcessor::createScalarCombiner(
fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
mlir::Type type, mlir::Value op1, mlir::Value op2) {
mlir::Value reductionOp;
type = fir::unwrapRefType(type);
switch (redId) {
case ReductionIdentifier::MAX:
reductionOp =
getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
builder, type, loc, op1, op2);
break;
case ReductionIdentifier::MIN:
reductionOp =
getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
builder, type, loc, op1, op2);
break;
case ReductionIdentifier::IOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = mlir::arith::OrIOp::create(builder, loc, op1, op2);
break;
case ReductionIdentifier::IEOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = mlir::arith::XOrIOp::create(builder, loc, op1, op2);
break;
case ReductionIdentifier::IAND:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = mlir::arith::AndIOp::create(builder, loc, op1, op2);
break;
case ReductionIdentifier::ADD:
reductionOp =
getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
fir::AddcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::MULTIPLY:
reductionOp =
getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
fir::MulcOp>(builder, type, loc, op1, op2);
break;
case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
mlir::Value andiOp =
mlir::arith::AndIOp::create(builder, loc, op1I1, op2I1);
reductionOp = builder.createConvert(loc, type, andiOp);
break;
}
case ReductionIdentifier::OR: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
mlir::Value oriOp = mlir::arith::OrIOp::create(builder, loc, op1I1, op2I1);
reductionOp = builder.createConvert(loc, type, oriOp);
break;
}
case ReductionIdentifier::EQV: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
mlir::Value cmpiOp = mlir::arith::CmpIOp::create(
builder, loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
reductionOp = builder.createConvert(loc, type, cmpiOp);
break;
}
case ReductionIdentifier::NEQV: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
mlir::Value cmpiOp = mlir::arith::CmpIOp::create(
builder, loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
reductionOp = builder.createConvert(loc, type, cmpiOp);
break;
}
default:
TODO(loc, "Reduction of some intrinsic operators is not supported");
}
return reductionOp;
}
template <typename ParentDeclOpType>
static void genYield(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value yieldedValue) {
if constexpr (std::is_same_v<ParentDeclOpType, mlir::omp::DeclareReductionOp>)
mlir::omp::YieldOp::create(builder, loc, yieldedValue);
else
fir::YieldOp::create(builder, loc, yieldedValue);
}
/// Create reduction combiner region for reduction variables which are boxed
/// arrays
template <typename DeclRedOpType>
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
ReductionProcessor::ReductionIdentifier redId,
fir::BaseBoxType boxTy, mlir::Value lhs,
mlir::Value rhs) {
fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
fir::unwrapRefType(boxTy.getEleTy()));
fir::HeapType heapTy =
mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
fir::PointerType ptrTy =
mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
TODO(loc, "Unsupported boxed type in OpenMP reduction");
// load fir.ref<fir.box<...>>
mlir::Value lhsAddr = lhs;
lhs = fir::LoadOp::create(builder, loc, lhs);
rhs = fir::LoadOp::create(builder, loc, rhs);
if ((heapTy || ptrTy) && !seqTy) {
// get box contents (heap pointers)
lhs = fir::BoxAddrOp::create(builder, loc, lhs);
rhs = fir::BoxAddrOp::create(builder, loc, rhs);
mlir::Value lhsValAddr = lhs;
// load heap pointers
lhs = fir::LoadOp::create(builder, loc, lhs);
rhs = fir::LoadOp::create(builder, loc, rhs);
mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();
mlir::Value result = ReductionProcessor::createScalarCombiner(
builder, loc, redId, eleTy, lhs, rhs);
fir::StoreOp::create(builder, loc, result, lhsValAddr);
genYield<DeclRedOpType>(builder, loc, lhsAddr);
return;
}
// Get ShapeShift with default lower bounds. This makes it possible to use
// unmodified LoopNest's indices with ArrayCoorOp.
fir::ShapeShiftOp shapeShift =
getShapeShift(builder, loc, lhs,
/*cannotHaveNonDefaultLowerBounds=*/false,
/*useDefaultLowerBounds=*/true);
// Iterate over array elements, applying the equivalent scalar reduction:
// F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
// and so no null check is needed here before indexing into the (possibly
// allocatable) arrays.
// A hlfir::elemental here gets inlined with a temporary so create the
// loop nest directly.
// This function already controls all of the code in this region so we
// know this won't miss any opportuinties for clever elemental inlining
hlfir::LoopNest nest = hlfir::genLoopNest(
loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
builder.setInsertionPointToStart(nest.body);
const bool seqIsVolatile = fir::isa_volatile_type(seqTy.getEleTy());
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy(), seqIsVolatile);
auto lhsEleAddr = fir::ArrayCoorOp::create(
builder, loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
auto rhsEleAddr = fir::ArrayCoorOp::create(
builder, loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
auto lhsEle = fir::LoadOp::create(builder, loc, lhsEleAddr);
auto rhsEle = fir::LoadOp::create(builder, loc, rhsEleAddr);
mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
builder, loc, redId, refTy, lhsEle, rhsEle);
fir::StoreOp::create(builder, loc, scalarReduction, lhsEleAddr);
builder.setInsertionPointAfter(nest.outerOp);
genYield<DeclRedOpType>(builder, loc, lhsAddr);
}
// generate combiner region for reduction operations
template <typename DeclRedOpType>
static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
ReductionProcessor::ReductionIdentifier redId,
mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
bool isByRef) {
ty = fir::unwrapRefType(ty);
if (fir::isa_trivial(ty)) {
mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
mlir::Value result = ReductionProcessor::createScalarCombiner(
builder, loc, redId, ty, lhsLoaded, rhsLoaded);
if (isByRef) {
fir::StoreOp::create(builder, loc, result, lhs);
genYield<DeclRedOpType>(builder, loc, lhs);
} else {
genYield<DeclRedOpType>(builder, loc, result);
}
return;
}
// all arrays should have been boxed
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
genBoxCombiner<DeclRedOpType>(builder, loc, redId, boxTy, lhs, rhs);
return;
}
TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
}
// like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
return seqTy.getEleTy();
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
auto eleTy = fir::unwrapRefType(boxTy.getEleTy());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
return seqTy.getEleTy();
return eleTy;
}
return ty;
}
template <typename OpType>
static void createReductionAllocAndInitRegions(
AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
bool isByRef, const Fortran::semantics::Symbol *sym) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };
mlir::Block *allocBlock = nullptr;
mlir::Block *initBlock = nullptr;
if (isByRef) {
allocBlock =
builder.createBlock(&reductionDecl.getAllocRegion(),
reductionDecl.getAllocRegion().end(), {}, {});
initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
reductionDecl.getInitializerRegion().end(),
{type, type}, {loc, loc});
} else {
initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
reductionDecl.getInitializerRegion().end(),
{type}, {loc});
}
mlir::Type ty = fir::unwrapRefType(type);
builder.setInsertionPointToEnd(initBlock);
mlir::Value initValue =
isByRef ? genInitValueCB(builder, loc, ty, initBlock->getArgument(0),
initBlock->getArgument(1))
: genInitValueCB(builder, loc, ty, initBlock->getArgument(0),
mlir::Value{});
if (isByRef) {
populateByRefInitAndCleanupRegions(
converter, loc, type, initValue, initBlock,
reductionDecl.getInitializerAllocArg(),
reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(),
DeclOperationKind::Reduction, sym,
/*cannotHaveLowerBounds=*/false,
/*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
}
if (fir::isa_trivial(ty) || fir::isa_derived(ty)) {
if (isByRef) {
// alloc region
builder.setInsertionPointToEnd(allocBlock);
mlir::Value alloca = fir::AllocaOp::create(builder, loc, ty);
yield(alloca);
return;
}
// by val
yield(initValue);
return;
}
assert(isByRef && "passing non-trivial types by val is unsupported");
// alloc region
builder.setInsertionPointToEnd(allocBlock);
mlir::Value boxAlloca = fir::AllocaOp::create(builder, loc, ty);
yield(boxAlloca);
}
template <typename DeclareRedType>
DeclareRedType ReductionProcessor::createDeclareReductionHelper(
AbstractConverter &converter, llvm::StringRef reductionOpName,
mlir::Type type, mlir::Location loc, bool isByRef,
GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB,
const semantics::Symbol *sym) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertionGuard guard(builder);
mlir::ModuleOp module = builder.getModule();
assert(!reductionOpName.empty());
auto decl = module.lookupSymbol<DeclareRedType>(reductionOpName);
if (decl)
return decl;
mlir::OpBuilder modBuilder(module.getBodyRegion());
mlir::Type valTy = fir::unwrapRefType(type);
// For by-ref reductions, we want to keep track of the
// boxed/referenced/allocated type. For example, for a `real, allocatable`
// variable, `real` should be stored.
mlir::TypeAttr boxedTyAttr{};
mlir::Type boxedTy;
if (isByRef) {
boxedTy = fir::unwrapPassByRefType(valTy);
boxedTyAttr = mlir::TypeAttr::get(boxedTy);
// For character types that are not already references, we need to wrap
// them in a reference type for by-ref reductions.
if (fir::isa_char(valTy) && !fir::isa_ref_type(type)) {
type = fir::ReferenceType::get(valTy);
}
} else
type = valTy;
decl = DeclareRedType::create(modBuilder, loc, reductionOpName, type,
boxedTyAttr);
createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB, type,
isByRef, sym);
builder.createBlock(&decl.getReductionRegion(),
decl.getReductionRegion().end(), {type, type},
{loc, loc});
builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
genCombinerCB(builder, loc, type, op1, op2, isByRef);
if (isByRef && fir::isa_box_type(valTy)) {
mlir::Region &dataPtrPtrRegion = decl.getDataPtrPtrRegion();
mlir::Block &dataAddrBlock = *builder.createBlock(
&dataPtrPtrRegion, dataPtrPtrRegion.end(), {type}, {loc});
builder.setInsertionPointToEnd(&dataAddrBlock);
mlir::Value boxRefOperand = dataAddrBlock.getArgument(0);
mlir::Value baseAddrOffset = fir::BoxOffsetOp::create(
builder, loc, boxRefOperand, fir::BoxFieldAttr::base_addr);
genYield<DeclareRedType>(builder, loc, baseAddrOffset);
}
return decl;
}
template <typename OpType>
OpType ReductionProcessor::createDeclareReduction(
AbstractConverter &converter, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef) {
auto genInitValueCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value /*moldArg*/,
mlir::Value /*privArg*/) {
mlir::Type ty = fir::unwrapRefType(type);
mlir::Value initValue = ReductionProcessor::getReductionInitValue(
loc, unwrapSeqOrBoxedType(ty), redId, builder);
return initValue;
};
auto genCombinerCB = [&](fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type type, mlir::Value op1, mlir::Value op2,
bool isByRef) {
genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
};
return createDeclareReductionHelper<OpType>(converter, reductionOpName, type,
loc, isByRef, genCombinerCB,
genInitValueCB);
}
bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
if (forceByrefReduction)
return true;
// Non-trivial, non-derived types (e.g., boxes, arrays) must be by-ref.
// Derived types must also be by-ref because user-defined combiners
// operate on components via side-effects, not by producing a whole value.
if (!fir::isa_trivial(fir::unwrapRefType(reductionType)))
return true;
return false;
}
bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
if (forceByrefReduction)
return true;
if (auto declare =
mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
reductionVar = declare.getMemref();
return doReductionByRef(reductionVar.getType());
}
template <typename OpType, typename RedOperatorListTy>
bool ReductionProcessor::processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const RedOperatorListTy &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
if constexpr (std::is_same_v<RedOperatorListTy,
omp::clause::ReductionOperatorList>) {
// For OpenMP reduction clauses, check if the reduction operator is
// supported.
assert(redOperatorList.size() == 1 && "Expecting single operator");
const Fortran::lower::omp::clause::ReductionOperator &redOperator =
redOperatorList.front();
if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
if (const auto *reductionIntrinsic =
std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
// If not an intrinsic is has to be a custom reduction op, and should
// be available in the module.
semantics::Symbol *sym = reductionIntrinsic->v.sym();
mlir::ModuleOp module = builder.getModule();
auto decl = module.lookupSymbol<OpType>(getRealName(sym).ToString());
if (!decl)
return false;
}
} else {
return false;
}
}
}
// Reduction variable processing common to both intrinsic operators and
// procedure designators
mlir::OpBuilder::InsertPoint dcIP;
constexpr bool isDoConcurrent =
std::is_same_v<OpType, fir::DeclareReductionOp>;
if (isDoConcurrent) {
dcIP = builder.saveInsertionPoint();
builder.setInsertionPoint(
builder.getRegion().getParentOfType<fir::DoConcurrentOp>());
}
for (const semantics::Symbol *symbol : reductionSymbols) {
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
mlir::Type eleType;
auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
if (refType)
eleType = refType.getEleTy();
else
eleType = symVal.getType();
// all arrays must be boxed so that we have convenient access to all the
// information needed to iterate over the array
if (mlir::isa<fir::SequenceType>(eleType)) {
// For Host associated symbols, use `SymbolBox` instead
lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
hlfir::Entity entity{symBox.getAddr()};
entity = genVariableBox(currentLocation, builder, entity);
mlir::Value box = entity.getBase();
// Always pass the box by reference so that the OpenMP dialect
// verifiers don't need to know anything about fir.box
auto alloca =
fir::AllocaOp::create(builder, currentLocation, box.getType());
fir::StoreOp::create(builder, currentLocation, box, alloca);
symVal = alloca;
} else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
// boxed arrays are passed as values not by reference. Unfortunately,
// we can't pass a box by value to omp.redution_declare, so turn it
// into a reference
auto oldIP = builder.saveInsertionPoint();
builder.setInsertionPointToStart(builder.getAllocaBlock());
auto alloca =
fir::AllocaOp::create(builder, currentLocation, symVal.getType());
builder.restoreInsertionPoint(oldIP);
fir::StoreOp::create(builder, currentLocation, symVal, alloca);
symVal = alloca;
}
// this isn't the same as the by-val and by-ref passing later in the
// pipeline. Both styles assume that the variable is a reference at
// this point
assert(fir::isa_ref_type(symVal.getType()) &&
"reduction input var is passed by reference");
mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
const bool symIsVolatile = fir::isa_volatile_type(symVal.getType());
mlir::Type refTy = fir::ReferenceType::get(elementType, symIsVolatile);
reductionVars.push_back(
builder.createConvert(currentLocation, refTy, symVal));
reduceVarByRef.push_back(doReductionByRef(symVal));
}
unsigned idx = 0;
for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
const auto &kindMap = builder.getKindMap();
std::string reductionName;
ReductionIdentifier redId;
if constexpr (std::is_same_v<RedOperatorListTy,
omp::clause::ReductionOperatorList>) {
const Fortran::lower::omp::clause::ReductionOperator &redOperator =
redOperatorList.front();
if (const auto &redDefinedOp =
std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
redId = getReductionType(intrinsicOp);
switch (redId) {
case ReductionIdentifier::ADD:
case ReductionIdentifier::MULTIPLY:
case ReductionIdentifier::AND:
case ReductionIdentifier::EQV:
case ReductionIdentifier::OR:
case ReductionIdentifier::NEQV:
break;
default:
TODO(currentLocation,
"Reduction of some intrinsic operators is not supported");
break;
}
reductionName = getReductionName(redId, kindMap, redType, isByRef);
// If a user-defined declare reduction already exists for this
// operator+type, reuse it instead of generating a new one
// (which would fail for non-predefined types like derived types).
mlir::ModuleOp module = builder.getModule();
if (auto existingDecl = module.lookupSymbol<OpType>(reductionName)) {
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
builder.getContext(), existingDecl.getSymName()));
++idx;
continue;
}
} else if (const auto *reductionIntrinsic =
std::get_if<omp::clause::ProcedureDesignator>(
&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
// Custom reductions we can just add to the symbols without
// generating the declare reduction op.
semantics::Symbol *sym = reductionIntrinsic->v.sym();
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
builder.getContext(), sym->name().ToString()));
++idx;
continue;
}
redId = getReductionType(*reductionIntrinsic);
reductionName =
getReductionName(getRealName(*reductionIntrinsic).ToString(),
kindMap, redType, isByRef);
} else {
TODO(currentLocation, "Unexpected reduction type");
}
} else {
// `do concurrent` reductions
redId = getReductionType(redOperatorList[idx]);
reductionName = getReductionName(redId, kindMap, redType, isByRef);
}
OpType decl = createDeclareReduction<OpType>(
converter, reductionName, redId, redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(builder.getContext(), decl.getSymName()));
++idx;
}
if (isDoConcurrent)
builder.restoreInsertionPoint(dcIP);
return true;
}
const semantics::SourceName
ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
return symbol->GetUltimate().name();
}
const semantics::SourceName
ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
return getRealName(pd.v.sym());
}
int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
mlir::Location loc) {
switch (redId) {
case ReductionIdentifier::ADD:
case ReductionIdentifier::OR:
case ReductionIdentifier::NEQV:
return 0;
case ReductionIdentifier::MULTIPLY:
case ReductionIdentifier::AND:
case ReductionIdentifier::EQV:
return 1;
default:
TODO(loc, "Reduction of some intrinsic operators is not supported");
}
}
} // namespace omp
} // namespace lower
} // namespace Fortran