llvm-project/flang/lib/Lower/Support/ReductionProcessor.cpp
Jan Leyonberg 3e86f05621
[OpenMP][flang] Lowering of OpenMP custom reductions to MLIR (#168417)
This patch adds support for lowering custom reductions to MLIR. It
also enhances the capability of the pass to automatically mark functions
as "declare target" by traversing custom reduction initializers and
combiners.
2025-11-24 16:00:46 -05:00

841 lines
33 KiB
C++

//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
#include "flang/Lower/Support/ReductionProcessor.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include "flang/Lower/Support/PrivateReductionUtils.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/Complex.h"
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>
// Hidden command-line flag. When it is set, both doReductionByRef() overloads
// in this file answer "by reference" without looking at the variable's type.
static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
llvm::cl::desc("Pass all reduction arguments by reference"),
llvm::cl::Hidden);
// Shorthand for the modifier type carried by the OpenMP reduction clause.
using ReductionModifier =
Fortran::lower::omp::clause::Reduction::ReductionModifier;
namespace Fortran {
namespace lower {
namespace omp {
// Explicit template instantiations. The member templates are defined further
// down in this file; they are instantiated here for the two declare-reduction
// op flavors this file handles: mlir::omp::DeclareReductionOp (OpenMP
// reduction clauses) and fir::DeclareReductionOp (`do concurrent` reductions).
template bool ReductionProcessor::processReductionArguments<
mlir::omp::DeclareReductionOp, omp::clause::ReductionOperatorList>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::ReductionOperatorList &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
template bool ReductionProcessor::processReductionArguments<
fir::DeclareReductionOp, llvm::SmallVector<fir::ReduceOperationEnum>>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const llvm::SmallVector<fir::ReduceOperationEnum> &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
template mlir::omp::DeclareReductionOp
ReductionProcessor::createDeclareReduction<mlir::omp::DeclareReductionOp>(
AbstractConverter &converter, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef);
template fir::DeclareReductionOp
ReductionProcessor::createDeclareReduction<fir::DeclareReductionOp>(
AbstractConverter &converter, llvm::StringRef reductionOpName,
const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
bool isByRef);
/// Translate an intrinsic procedure used as a reduction operator into its
/// ReductionIdentifier. The name compared is the one getRealName() reports for
/// the designator's symbol; only max, min, iand, ior and ieor are recognized
/// and anything else trips the assertion.
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
    const omp::clause::ProcedureDesignator &pd) {
  using MaybeRedId = std::optional<ReductionIdentifier>;

  const std::string procName = getRealName(pd.v.sym()).ToString();
  const MaybeRedId match = llvm::StringSwitch<MaybeRedId>(procName)
                               .Case("max", ReductionIdentifier::MAX)
                               .Case("min", ReductionIdentifier::MIN)
                               .Case("iand", ReductionIdentifier::IAND)
                               .Case("ior", ReductionIdentifier::IOR)
                               .Case("ieor", ReductionIdentifier::IEOR)
                               .Default(std::nullopt);
  assert(match && "Invalid Reduction");
  return *match;
}
/// Translate an intrinsic operator from a reduction clause into the matching
/// ReductionIdentifier. Operators without a mapping are treated as
/// unreachable.
ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
    omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
  using IntrinsicOp = omp::clause::DefinedOperator::IntrinsicOperator;
  using RedId = ReductionIdentifier;

  switch (intrinsicOp) {
  case IntrinsicOp::Add:
    return RedId::ADD;
  case IntrinsicOp::Subtract:
    return RedId::SUBTRACT;
  case IntrinsicOp::Multiply:
    return RedId::MULTIPLY;
  case IntrinsicOp::AND:
    return RedId::AND;
  case IntrinsicOp::EQV:
    return RedId::EQV;
  case IntrinsicOp::OR:
    return RedId::OR;
  case IntrinsicOp::NEQV:
    return RedId::NEQV;
  default:
    llvm_unreachable("unexpected intrinsic operator in reduction");
  }
}
/// Translate a FIR reduce operation into the matching ReductionIdentifier.
/// The switch deliberately has no default so that every enumerator is listed
/// explicitly.
ReductionProcessor::ReductionIdentifier
ReductionProcessor::getReductionType(const fir::ReduceOperationEnum &redOp) {
  using FirOp = fir::ReduceOperationEnum;
  using RedId = ReductionIdentifier;

  switch (redOp) {
  case FirOp::Add:
    return RedId::ADD;
  case FirOp::Multiply:
    return RedId::MULTIPLY;
  case FirOp::AND:
    return RedId::AND;
  case FirOp::OR:
    return RedId::OR;
  case FirOp::EQV:
    return RedId::EQV;
  case FirOp::NEQV:
    return RedId::NEQV;
  case FirOp::IAND:
    return RedId::IAND;
  case FirOp::IEOR:
    return RedId::IEOR;
  case FirOp::IOR:
    return RedId::IOR;
  case FirOp::MAX:
    return RedId::MAX;
  case FirOp::MIN:
    return RedId::MIN;
  }
  llvm_unreachable("Unhandled ReductionIdentifier case");
}
/// Report whether \p pd names an intrinsic procedure that this file can lower
/// as a reduction: the ultimate symbol must carry the INTRINSIC attribute and
/// be one of max, min, iand, ior or ieor.
bool ReductionProcessor::supportedIntrinsicProcReduction(
    const omp::clause::ProcedureDesignator &pd) {
  semantics::Symbol *procSym = pd.v.sym();
  const bool isIntrinsic =
      procSym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC);
  if (!isIntrinsic)
    return false;

  const std::string procName = getRealName(procSym).ToString();
  return llvm::StringSwitch<bool>(procName)
      .Case("max", true)
      .Case("min", true)
      .Case("iand", true)
      .Case("ior", true)
      .Case("ieor", true)
      .Default(false);
}
/// Compose the symbol name of a reduction declaration from \p name and the
/// reduced type (with any reference wrapper removed). By-reference reductions
/// get a "_byref" suffix appended to \p name so that they do not share a
/// declaration with their by-value counterparts.
std::string
ReductionProcessor::getReductionName(llvm::StringRef name,
                                     const fir::KindMapping &kindMap,
                                     mlir::Type ty, bool isByRef) {
  const llvm::StringRef byrefSuffix = isByRef ? "_byref" : "";
  const std::string prefix = (name + byrefSuffix).str();
  return fir::getTypeAsString(fir::unwrapRefType(ty), kindMap, prefix);
}
/// Compose the symbol name of a reduction declaration for an intrinsic
/// operator reduction. Identifiers without a dedicated base name share
/// "other_reduction"; the type and by-ref decoration is delegated to the
/// StringRef overload.
std::string
ReductionProcessor::getReductionName(ReductionIdentifier redId,
                                     const fir::KindMapping &kindMap,
                                     mlir::Type ty, bool isByRef) {
  const llvm::StringRef baseName = [redId]() -> llvm::StringRef {
    switch (redId) {
    case ReductionIdentifier::ADD:
      return "add_reduction";
    case ReductionIdentifier::MULTIPLY:
      return "multiply_reduction";
    case ReductionIdentifier::AND:
      return "and_reduction";
    case ReductionIdentifier::EQV:
      return "eqv_reduction";
    case ReductionIdentifier::OR:
      return "or_reduction";
    case ReductionIdentifier::NEQV:
      return "neqv_reduction";
    default:
      return "other_reduction";
    }
  }();
  return getReductionName(baseName, kindMap, ty, isByRef);
}
/// Build the constant that seeds a reduction of kind \p redId over values of
/// \p type (any reference wrapper is stripped first).
///
/// - MAX / MIN: the smallest / largest representable value; for floating
///   point this is the largest-magnitude finite value of the matching sign.
/// - IOR / IEOR: all bits clear. IAND: all bits set.
/// - ADD, MULTIPLY, AND, OR, EQV, NEQV: the identity reported by
///   getOperationIdentity(), materialized as a complex, real, logical or
///   integer constant depending on \p type.
///
/// Types other than integer, real, complex and logical, as well as the
/// identifiers ID, USER_DEF_OP and SUBTRACT, end in a TODO diagnostic.
mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                          ReductionIdentifier redId,
                                          fir::FirOpBuilder &builder) {
  type = fir::unwrapRefType(type);
  if (!fir::isa_integer(type) && !fir::isa_real(type) &&
      !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
    TODO(loc, "Reduction of some types is not supported");

  // Materialize an integer constant directly from an APInt. Going through
  // int64_t (APInt::getSExtValue) cannot represent the extreme values of
  // integers wider than 64 bits, so the seed of e.g. a 128-bit MAX/MIN
  // reduction would be rejected by an assertion or silently truncated.
  auto createIntConstant = [&](const llvm::APInt &value) -> mlir::Value {
    return mlir::arith::ConstantOp::create(
        builder, loc, type, builder.getIntegerAttr(type, value));
  };

  switch (redId) {
  case ReductionIdentifier::MAX: {
    if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
    }
    unsigned bits = type.getIntOrFloatBitWidth();
    return createIntConstant(llvm::APInt::getSignedMinValue(bits));
  }
  case ReductionIdentifier::MIN: {
    if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
    }
    unsigned bits = type.getIntOrFloatBitWidth();
    return createIntConstant(llvm::APInt::getSignedMaxValue(bits));
  }
  // Bitwise OR and XOR share the same identity: no bit set.
  case ReductionIdentifier::IOR:
  case ReductionIdentifier::IEOR: {
    unsigned bits = type.getIntOrFloatBitWidth();
    return createIntConstant(llvm::APInt::getZero(bits));
  }
  case ReductionIdentifier::IAND: {
    unsigned bits = type.getIntOrFloatBitWidth();
    return createIntConstant(llvm::APInt::getAllOnes(bits));
  }
  case ReductionIdentifier::ADD:
  case ReductionIdentifier::MULTIPLY:
  case ReductionIdentifier::AND:
  case ReductionIdentifier::OR:
  case ReductionIdentifier::EQV:
  case ReductionIdentifier::NEQV:
    if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
      // The identity goes into the real part; the imaginary part is zero.
      mlir::Type realTy = cplxTy.getElementType();
      mlir::Value initRe = builder.createRealConstant(
          loc, realTy, getOperationIdentity(redId, loc));
      mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
      return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
                                                               initIm);
    }
    if (mlir::isa<mlir::FloatType>(type))
      return mlir::arith::ConstantOp::create(
          builder, loc, type,
          builder.getFloatAttr(
              type, static_cast<double>(getOperationIdentity(redId, loc))));
    if (mlir::isa<fir::LogicalType>(type)) {
      // Logical identities are created as i1 and then converted to the
      // requested fir.logical type.
      mlir::Value intConst = mlir::arith::ConstantOp::create(
          builder, loc, builder.getI1Type(),
          builder.getIntegerAttr(builder.getI1Type(),
                                 getOperationIdentity(redId, loc)));
      return builder.createConvert(loc, type, intConst);
    }
    return mlir::arith::ConstantOp::create(
        builder, loc, type,
        builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
  case ReductionIdentifier::ID:
  case ReductionIdentifier::USER_DEF_OP:
  case ReductionIdentifier::SUBTRACT:
    TODO(loc, "Reduction of some identifier types is not supported");
  }
  llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
}
/// Emit the operation that combines two already-loaded scalar operands for
/// the reduction \p redId and return its result. \p type is the operand type;
/// a reference wrapper, if present, is stripped.
///
/// MAX, MIN, ADD and MULTIPLY dispatch through getReductionOperation with the
/// listed floating-point, integer and (where given) complex ops. IOR, IEOR and
/// IAND expect integer operands. The logical reductions work on i1 and convert
/// the result back to \p type. Any other identifier ends in a TODO diagnostic.
mlir::Value ReductionProcessor::createScalarCombiner(
    fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
    mlir::Type type, mlir::Value op1, mlir::Value op2) {
  type = fir::unwrapRefType(type);

  // Shared skeleton of the logical reductions: narrow both operands to i1,
  // let `combineBits` emit the actual operation, widen the result again.
  auto combineLogical = [&](auto combineBits) -> mlir::Value {
    mlir::Value lhsBit = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value rhsBit = builder.createConvert(loc, builder.getI1Type(), op2);
    mlir::Value bitResult = combineBits(lhsBit, rhsBit);
    return builder.createConvert(loc, type, bitResult);
  };

  mlir::Value combined;
  switch (redId) {
  case ReductionIdentifier::MAX:
    combined =
        getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
            builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MIN:
    combined =
        getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
            builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::IOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    combined = mlir::arith::OrIOp::create(builder, loc, op1, op2);
    break;
  case ReductionIdentifier::IEOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    combined = mlir::arith::XOrIOp::create(builder, loc, op1, op2);
    break;
  case ReductionIdentifier::IAND:
    assert((type.isIntOrIndex()) && "only integer is expected");
    combined = mlir::arith::AndIOp::create(builder, loc, op1, op2);
    break;
  case ReductionIdentifier::ADD:
    combined =
        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
                              fir::AddcOp>(builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MULTIPLY:
    combined =
        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
                              fir::MulcOp>(builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::AND:
    combined = combineLogical([&](mlir::Value a, mlir::Value b) -> mlir::Value {
      return mlir::arith::AndIOp::create(builder, loc, a, b);
    });
    break;
  case ReductionIdentifier::OR:
    combined = combineLogical([&](mlir::Value a, mlir::Value b) -> mlir::Value {
      return mlir::arith::OrIOp::create(builder, loc, a, b);
    });
    break;
  case ReductionIdentifier::EQV:
    combined = combineLogical([&](mlir::Value a, mlir::Value b) -> mlir::Value {
      return mlir::arith::CmpIOp::create(
          builder, loc, mlir::arith::CmpIPredicate::eq, a, b);
    });
    break;
  case ReductionIdentifier::NEQV:
    combined = combineLogical([&](mlir::Value a, mlir::Value b) -> mlir::Value {
      return mlir::arith::CmpIOp::create(
          builder, loc, mlir::arith::CmpIPredicate::ne, a, b);
    });
    break;
  default:
    TODO(loc, "Reduction of some intrinsic operators is not supported");
  }
  return combined;
}
/// Terminate the current region by yielding \p yieldedValue, using the yield
/// op of the dialect that \p ParentDeclOpType belongs to: omp.yield for
/// mlir::omp::DeclareReductionOp, fir's yield for everything else.
template <typename ParentDeclOpType>
static void genYield(fir::FirOpBuilder &builder, mlir::Location loc,
                     mlir::Value yieldedValue) {
  constexpr bool parentIsOmpDecl =
      std::is_same_v<ParentDeclOpType, mlir::omp::DeclareReductionOp>;
  if constexpr (!parentIsOmpDecl) {
    fir::YieldOp::create(builder, loc, yieldedValue);
  } else {
    mlir::omp::YieldOp::create(builder, loc, yieldedValue);
  }
}
/// Create the reduction combiner region body for reduction variables that are
/// passed as references to boxes (descriptors).
///
/// Two shapes are handled:
///  - a box around a heap or pointer scalar: the two scalars are loaded,
///    combined with createScalarCombiner, and the result is stored back
///    through the left-hand side's address;
///  - a box around an array with known shape: a loop nest walks the elements
///    and applies the same scalar combiner element by element, storing into
///    the left-hand side array.
/// In both cases the original \p lhs reference is what gets yielded. Any other
/// boxed type ends in a TODO diagnostic.
template <typename DeclRedOpType>
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
ReductionProcessor::ReductionIdentifier redId,
fir::BaseBoxType boxTy, mlir::Value lhs,
mlir::Value rhs) {
// Classify the boxed entity: seqTy is set for arrays (also when the array is
// wrapped in a reference-like type), heapTy/ptrTy when the box directly holds
// a heap or pointer entity.
fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
fir::unwrapRefType(boxTy.getEleTy()));
fir::HeapType heapTy =
mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
fir::PointerType ptrTy =
mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
TODO(loc, "Unsupported boxed type in OpenMP reduction");
// load fir.ref<fir.box<...>>
// Keep the incoming reference: it is the value yielded at the end.
mlir::Value lhsAddr = lhs;
lhs = fir::LoadOp::create(builder, loc, lhs);
rhs = fir::LoadOp::create(builder, loc, rhs);
// Boxed heap/pointer scalar (no sequence type involved).
if ((heapTy || ptrTy) && !seqTy) {
// get box contents (heap pointers)
lhs = fir::BoxAddrOp::create(builder, loc, lhs);
rhs = fir::BoxAddrOp::create(builder, loc, rhs);
// Address of the left-hand scalar; the combined value is written here.
mlir::Value lhsValAddr = lhs;
// load heap pointers
lhs = fir::LoadOp::create(builder, loc, lhs);
rhs = fir::LoadOp::create(builder, loc, rhs);
mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();
mlir::Value result = ReductionProcessor::createScalarCombiner(
builder, loc, redId, eleTy, lhs, rhs);
fir::StoreOp::create(builder, loc, result, lhsValAddr);
genYield<DeclRedOpType>(builder, loc, lhsAddr);
return;
}
// Get ShapeShift with default lower bounds. This makes it possible to use
// unmodified LoopNest's indices with ArrayCoorOp.
fir::ShapeShiftOp shapeShift =
getShapeShift(builder, loc, lhs,
/*cannotHaveNonDefaultLowerBounds=*/false,
/*useDefaultLowerBounds=*/true);
// Iterate over array elements, applying the equivalent scalar reduction:
// F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
// and so no null check is needed here before indexing into the (possibly
// allocatable) arrays.
// A hlfir::elemental here gets inlined with a temporary so create the
// loop nest directly.
// This function already controls all of the code in this region so we
// know this won't miss any opportunities for clever elemental inlining
hlfir::LoopNest nest = hlfir::genLoopNest(
loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
builder.setInsertionPointToStart(nest.body);
// Element references keep the volatility of the array's element type.
const bool seqIsVolatile = fir::isa_volatile_type(seqTy.getEleTy());
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy(), seqIsVolatile);
auto lhsEleAddr = fir::ArrayCoorOp::create(
builder, loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
auto rhsEleAddr = fir::ArrayCoorOp::create(
builder, loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
auto lhsEle = fir::LoadOp::create(builder, loc, lhsEleAddr);
auto rhsEle = fir::LoadOp::create(builder, loc, rhsEleAddr);
// createScalarCombiner strips the reference wrapper from refTy itself.
mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
builder, loc, redId, refTy, lhsEle, rhsEle);
fir::StoreOp::create(builder, loc, scalarReduction, lhsEleAddr);
// Leave the loop nest before emitting the region terminator.
builder.setInsertionPointAfter(nest.outerOp);
genYield<DeclRedOpType>(builder, loc, lhsAddr);
}
/// Generate the body of the combiner region of a reduction declaration.
///
/// Trivial (scalar) types are combined directly: for by-reference reductions
/// the result is stored back into \p lhs and \p lhs is yielded, otherwise the
/// combined value itself is yielded. Boxed types are forwarded to
/// genBoxCombiner. Everything else ends in a TODO diagnostic.
template <typename DeclRedOpType>
static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                        ReductionProcessor::ReductionIdentifier redId,
                        mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
                        bool isByRef) {
  ty = fir::unwrapRefType(ty);

  if (fir::isa_trivial(ty)) {
    mlir::Value lhsVal = builder.loadIfRef(loc, lhs);
    mlir::Value rhsVal = builder.loadIfRef(loc, rhs);
    mlir::Value combined = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, ty, lhsVal, rhsVal);

    // By value the result is yielded as is; by reference it is written
    // through the left-hand side, which is then what the region yields.
    mlir::Value yielded = combined;
    if (isByRef) {
      fir::StoreOp::create(builder, loc, combined, lhs);
      yielded = lhs;
    }
    genYield<DeclRedOpType>(builder, loc, yielded);
    return;
  }

  // all arrays should have been boxed
  auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty);
  if (boxTy) {
    genBoxCombiner<DeclRedOpType>(builder, loc, redId, boxTy, lhs, rhs);
    return;
  }

  TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
}
// Return the element type behind an array and/or box: a sequence type yields
// its element type, a box yields what it holds (with a reference wrapper
// removed, and reduced to the element type if that is a sequence). Any other
// type is returned unchanged. Unlike fir::unwrapSeqOrBoxedSeqType this also
// works for boxes that do not hold a sequence.
static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
  mlir::Type inner = ty;
  if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty))
    inner = fir::unwrapRefType(boxTy.getEleTy());
  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(inner))
    return seqTy.getEleTy();
  return inner;
}
/// Populate the alloc and initializer regions of \p reductionDecl.
///
/// By value, only the initializer region is filled: it has one block argument
/// and yields the value produced by \p genInitValueCB. By reference, the
/// initializer block has two arguments and its body, together with the cleanup
/// region, is delegated to populateByRefInitAndCleanupRegions; the alloc
/// region then yields a fresh alloca of the unwrapped type. Types that are
/// neither trivial nor derived must be reduced by reference.
template <typename OpType>
static void createReductionAllocAndInitRegions(
    AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
    ReductionProcessor::GenInitValueCBTy genInitValueCB, mlir::Type type,
    bool isByRef) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };

  // The alloc block only exists for by-reference reductions and has to be
  // created before the initializer block.
  mlir::Block *allocBlock = nullptr;
  if (isByRef) {
    mlir::Region &allocRegion = reductionDecl.getAllocRegion();
    allocBlock = builder.createBlock(&allocRegion, allocRegion.end(), {}, {});
  }

  // One initializer argument by value, two by reference.
  const unsigned numInitArgs = isByRef ? 2 : 1;
  llvm::SmallVector<mlir::Type, 2> initArgTypes(numInitArgs, type);
  llvm::SmallVector<mlir::Location, 2> initArgLocs(numInitArgs, loc);
  mlir::Region &initRegion = reductionDecl.getInitializerRegion();
  mlir::Block *initBlock = builder.createBlock(&initRegion, initRegion.end(),
                                               initArgTypes, initArgLocs);

  mlir::Type ty = fir::unwrapRefType(type);
  builder.setInsertionPointToEnd(initBlock);
  mlir::Value initValue =
      genInitValueCB(builder, loc, ty, initBlock->getArgument(0));

  if (isByRef) {
    populateByRefInitAndCleanupRegions(
        converter, loc, type, initValue, initBlock,
        reductionDecl.getInitializerAllocArg(),
        reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(),
        DeclOperationKind::Reduction, /*sym=*/nullptr,
        /*cannotHaveLowerBounds=*/false,
        /*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
  }

  const bool isTrivialOrDerived = fir::isa_trivial(ty) || fir::isa_derived(ty);
  if (isTrivialOrDerived && !isByRef) {
    // by val
    yield(initValue);
    return;
  }

  assert(isByRef && "passing non-trivial types by val is unsupported");

  // alloc region
  builder.setInsertionPointToEnd(allocBlock);
  mlir::Value storage = fir::AllocaOp::create(builder, loc, ty);
  yield(storage);
}
/// Return the reduction declaration named \p reductionOpName, creating it in
/// the module if it does not exist yet.
///
/// A new declaration gets its alloc/initializer regions from
/// createReductionAllocAndInitRegions (using \p genInitValueCB) and a combiner
/// region with two block arguments whose body is produced by \p genCombinerCB.
/// For by-value reductions the declaration is typed with the unwrapped value
/// type. The caller's insertion point is restored on return.
template <typename DeclareRedType>
DeclareRedType ReductionProcessor::createDeclareReductionHelper(
    AbstractConverter &converter, llvm::StringRef reductionOpName,
    mlir::Type type, mlir::Location loc, bool isByRef,
    GenCombinerCBTy genCombinerCB, GenInitValueCBTy genInitValueCB) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::ModuleOp module = builder.getModule();

  assert(!reductionOpName.empty());

  // Reuse a declaration that was already emitted under this name.
  if (auto existing = module.lookupSymbol<DeclareRedType>(reductionOpName))
    return existing;

  mlir::OpBuilder modBuilder(module.getBodyRegion());
  const mlir::Type declTy = isByRef ? type : fir::unwrapRefType(type);

  DeclareRedType decl =
      DeclareRedType::create(modBuilder, loc, reductionOpName, declTy);
  createReductionAllocAndInitRegions(converter, loc, decl, genInitValueCB,
                                     declTy, isByRef);

  // Combiner region: one block taking the two values to combine.
  mlir::Region &combinerRegion = decl.getReductionRegion();
  builder.createBlock(&combinerRegion, combinerRegion.end(), {declTy, declTy},
                      {loc, loc});
  builder.setInsertionPointToEnd(&combinerRegion.back());
  mlir::Value lhsArg = combinerRegion.front().getArgument(0);
  mlir::Value rhsArg = combinerRegion.front().getArgument(1);
  genCombinerCB(builder, loc, declTy, lhsArg, rhsArg, isByRef);

  return decl;
}
/// Create (or look up) the declaration for an intrinsic reduction \p redId.
/// This wires the generic helper up with the two intrinsic-specific pieces:
/// the initial value from getReductionInitValue and the combiner from
/// genCombiner.
template <typename OpType>
OpType ReductionProcessor::createDeclareReduction(
    AbstractConverter &converter, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef) {
  // The combiner is fully determined by the reduction identifier.
  auto emitCombiner = [&](fir::FirOpBuilder &builder, mlir::Location loc,
                          mlir::Type type, mlir::Value op1, mlir::Value op2,
                          bool isByRef) {
    genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
  };

  // The initial value is computed for the element type: reference, box and
  // sequence wrappers are peeled off first. The value argument is not used.
  auto emitInitValue = [&](fir::FirOpBuilder &builder, mlir::Location loc,
                           mlir::Type type,
                           mlir::Value /*unused*/) -> mlir::Value {
    mlir::Type elementTy = unwrapSeqOrBoxedType(fir::unwrapRefType(type));
    return ReductionProcessor::getReductionInitValue(loc, elementTy, redId,
                                                     builder);
  };

  return createDeclareReductionHelper<OpType>(converter, reductionOpName, type,
                                              loc, isByRef, emitCombiner,
                                              emitInitValue);
}
/// Decide from the type alone whether a reduction has to be passed by
/// reference: always when the force-byref-reduction flag is set, otherwise
/// whenever the unwrapped type is neither trivial nor a derived type.
bool ReductionProcessor::doReductionByRef(mlir::Type reductionType) {
  if (forceByrefReduction)
    return true;

  const mlir::Type valueTy = fir::unwrapRefType(reductionType);
  const bool canPassByValue =
      fir::isa_trivial(valueTy) || fir::isa_derived(valueTy);
  return !canPassByValue;
}
/// Decide whether the reduction variable \p reductionVar has to be passed by
/// reference. If the value is produced by an hlfir.declare, the decision is
/// based on the type of the declared memref; otherwise on the value's own
/// type. The force-byref-reduction flag overrides everything.
bool ReductionProcessor::doReductionByRef(mlir::Value reductionVar) {
  if (forceByrefReduction)
    return true;

  // Value::getDefiningOp<OpTy>() yields a null op both when the defining op is
  // of another kind and when there is no defining op at all (block argument).
  // Casting the raw getDefiningOp() result with mlir::dyn_cast would assert on
  // the latter.
  if (auto declare = reductionVar.getDefiningOp<hlfir::DeclareOp>())
    reductionVar = declare.getMemref();

  return doReductionByRef(reductionVar.getType());
}
/// Lower the variables and operators of a reduction to the operands of a
/// reduction-carrying operation.
///
/// \p OpType is the declare-reduction op to create or look up, and
/// \p RedOperatorListTy is either omp::clause::ReductionOperatorList (OpenMP;
/// exactly one operator shared by all variables) or a list of
/// fir::ReduceOperationEnum (`do concurrent`; one operator per variable,
/// indexed in step with the variables).
///
/// For every symbol in \p reductionSymbols a reference-typed value is appended
/// to \p reductionVars and the matching by-reference decision to
/// \p reduceVarByRef. Then, for every entry of \p reductionVars, a symbol
/// reference to the reduction declaration to use is appended to
/// \p reductionDeclSymbols.
/// NOTE(review): the second loop walks the whole of \p reductionVars, not just
/// the entries added by this call - presumably callers pass empty output
/// vectors; confirm against the call sites.
///
/// \returns false if the OpenMP reduction operator cannot be handled here
/// (it is neither a defined operator nor a procedure designator, or it is a
/// non-intrinsic procedure for which no \p OpType declaration with that name
/// exists in the module); true otherwise.
template <typename OpType, typename RedOperatorListTy>
bool ReductionProcessor::processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const RedOperatorListTy &redOperatorList,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
if constexpr (std::is_same_v<RedOperatorListTy,
omp::clause::ReductionOperatorList>) {
// For OpenMP reduction clauses, check if the reduction operator is
// supported.
assert(redOperatorList.size() == 1 && "Expecting single operator");
const Fortran::lower::omp::clause::ReductionOperator &redOperator =
redOperatorList.front();
if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
if (const auto *reductionIntrinsic =
std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
// If it is not an intrinsic it has to be a custom reduction op, and
// should be available in the module.
// NOTE(review): the lookup uses getRealName(sym) while the symbol
// reference emitted further down uses sym->name(); confirm the two
// cannot differ for the symbols that reach this point.
semantics::Symbol *sym = reductionIntrinsic->v.sym();
mlir::ModuleOp module = builder.getModule();
auto decl = module.lookupSymbol<OpType>(getRealName(sym).ToString());
if (!decl)
return false;
}
} else {
return false;
}
}
}
// Reduction variable processing common to both intrinsic operators and
// procedure designators
mlir::OpBuilder::InsertPoint dcIP;
constexpr bool isDoConcurrent =
std::is_same_v<OpType, fir::DeclareReductionOp>;
// For `do concurrent`, emit the code below in front of the enclosing
// fir.do_concurrent op; the saved insertion point is restored before
// returning.
if (isDoConcurrent) {
dcIP = builder.saveInsertionPoint();
builder.setInsertionPoint(
builder.getRegion().getParentOfType<fir::DoConcurrentOp>());
}
// First pass: turn every reduction symbol into a reference-typed value.
for (const semantics::Symbol *symbol : reductionSymbols) {
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
mlir::Type eleType;
auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
if (refType)
eleType = refType.getEleTy();
else
eleType = symVal.getType();
// all arrays must be boxed so that we have convenient access to all the
// information needed to iterate over the array
if (mlir::isa<fir::SequenceType>(eleType)) {
// For Host associated symbols, use `SymbolBox` instead
lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
hlfir::Entity entity{symBox.getAddr()};
entity = genVariableBox(currentLocation, builder, entity);
mlir::Value box = entity.getBase();
// Always pass the box by reference so that the OpenMP dialect
// verifiers don't need to know anything about fir.box
auto alloca =
fir::AllocaOp::create(builder, currentLocation, box.getType());
fir::StoreOp::create(builder, currentLocation, box, alloca);
symVal = alloca;
} else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
// boxed arrays are passed as values not by reference. Unfortunately,
// we can't pass a box by value to the declare reduction op, so turn it
// into a reference
// The temporary is created in the alloca block; the store stays at the
// current insertion point.
auto oldIP = builder.saveInsertionPoint();
builder.setInsertionPointToStart(builder.getAllocaBlock());
auto alloca =
fir::AllocaOp::create(builder, currentLocation, symVal.getType());
builder.restoreInsertionPoint(oldIP);
fir::StoreOp::create(builder, currentLocation, symVal, alloca);
symVal = alloca;
}
// this isn't the same as the by-val and by-ref passing later in the
// pipeline. Both styles assume that the variable is a reference at
// this point
assert(fir::isa_ref_type(symVal.getType()) &&
"reduction input var is passed by reference");
// Normalize to a fir.ref of the pointee type, keeping volatility.
mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
const bool symIsVolatile = fir::isa_volatile_type(symVal.getType());
mlir::Type refTy = fir::ReferenceType::get(elementType, symIsVolatile);
reductionVars.push_back(
builder.createConvert(currentLocation, refTy, symVal));
reduceVarByRef.push_back(doReductionByRef(symVal));
}
// Second pass: pick or create the reduction declaration for every variable.
// idx tracks the position for the per-variable `do concurrent` operator list.
unsigned idx = 0;
for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
const auto &kindMap = builder.getKindMap();
std::string reductionName;
ReductionIdentifier redId;
if constexpr (std::is_same_v<RedOperatorListTy,
omp::clause::ReductionOperatorList>) {
const Fortran::lower::omp::clause::ReductionOperator &redOperator =
redOperatorList.front();
if (const auto &redDefinedOp =
std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
redId = getReductionType(intrinsicOp);
// Only these intrinsic operators are lowered; the rest is a TODO.
switch (redId) {
case ReductionIdentifier::ADD:
case ReductionIdentifier::MULTIPLY:
case ReductionIdentifier::AND:
case ReductionIdentifier::EQV:
case ReductionIdentifier::OR:
case ReductionIdentifier::NEQV:
break;
default:
TODO(currentLocation,
"Reduction of some intrinsic operators is not supported");
break;
}
reductionName = getReductionName(redId, kindMap, redType, isByRef);
} else if (const auto *reductionIntrinsic =
std::get_if<omp::clause::ProcedureDesignator>(
&redOperator.u)) {
if (!ReductionProcessor::supportedIntrinsicProcReduction(
*reductionIntrinsic)) {
// Custom reductions we can just add to the symbols without
// generating the declare reduction op.
semantics::Symbol *sym = reductionIntrinsic->v.sym();
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
builder.getContext(), sym->name().ToString()));
++idx;
continue;
}
redId = getReductionType(*reductionIntrinsic);
reductionName =
getReductionName(getRealName(*reductionIntrinsic).ToString(),
kindMap, redType, isByRef);
} else {
TODO(currentLocation, "Unexpected reduction type");
}
} else {
// `do concurrent` reductions
redId = getReductionType(redOperatorList[idx]);
reductionName = getReductionName(redId, kindMap, redType, isByRef);
}
OpType decl = createDeclareReduction<OpType>(
converter, reductionName, redId, redType, currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(builder.getContext(), decl.getSymName()));
++idx;
}
if (isDoConcurrent)
builder.restoreInsertionPoint(dcIP);
return true;
}
/// Return the name of the ultimate symbol that \p symbol resolves to.
const semantics::SourceName
ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
  const semantics::Symbol &ultimate = symbol->GetUltimate();
  return ultimate.name();
}
/// Return the name of the ultimate symbol behind the procedure designator
/// \p pd; forwards to the symbol overload.
const semantics::SourceName
ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
  const semantics::Symbol *procSym = pd.v.sym();
  return getRealName(procSym);
}
/// Return the integer identity element of the reduction \p redId: 1 for
/// MULTIPLY, AND and EQV, 0 for ADD, OR and NEQV. Other identifiers end in a
/// TODO diagnostic reported at \p loc.
int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
                                             mlir::Location loc) {
  switch (redId) {
  case ReductionIdentifier::MULTIPLY:
  case ReductionIdentifier::AND:
  case ReductionIdentifier::EQV:
    return 1;
  case ReductionIdentifier::ADD:
  case ReductionIdentifier::OR:
  case ReductionIdentifier::NEQV:
    return 0;
  default:
    TODO(loc, "Reduction of some intrinsic operators is not supported");
  }
}
} // namespace omp
} // namespace lower
} // namespace Fortran