Extends `do concurrent` device support by emitting compiler-generated declare mapper ops for live-ins whose types are record types and have allocatable members.
263 lines
11 KiB
C++
263 lines
11 KiB
C++
//===-- lib/Utils/OpenMP.cpp ------------------------------------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "flang/Utils/OpenMP.h"
|
|
|
|
#include "flang/Lower/ConvertExprToHLFIR.h"
|
|
#include "flang/Optimizer/Builder/DirectivesCommon.h"
|
|
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
|
#include "flang/Optimizer/Dialect/FIROps.h"
|
|
#include "flang/Optimizer/Dialect/FIRType.h"
|
|
|
|
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
namespace Fortran::utils::openmp {
|
|
/// Creates an `mlir::omp::MapInfoOp` describing the mapping of \p baseAddr.
///
/// If \p baseAddr has a `fir::BaseBoxType` type, a `fir::BoxAddrOp` is created
/// from it first and its result replaces both \p baseAddr and \p retTy. The
/// mapped variable type is taken from the element type of \p retTy, which is
/// cast to `mlir::omp::PointerLikeType`.
///
/// \param builder        Builder used to create the operations and attributes.
/// \param loc            Location attached to the created operations.
/// \param baseAddr       Value being mapped.
/// \param varPtrPtr      Forwarded unchanged to the map info op.
/// \param name           Name stored on the op as a string attribute.
/// \param bounds         Bounds values forwarded to the map info op.
/// \param members        Member map values forwarded to the map info op.
/// \param membersIndex   Member index attribute forwarded to the map info op.
/// \param mapType        Map flags, wrapped in a `ClauseMapFlagsAttr`.
/// \param mapCaptureType Capture kind, wrapped in a `VariableCaptureKindAttr`.
/// \param retTy          Result type of the op (overridden for box inputs).
/// \param partialMap     Stored on the op as a boolean attribute.
/// \param mapperId       Mapper symbol reference forwarded to the op.
/// \return The newly created map info operation.
mlir::omp::MapInfoOp createMapInfoOp(mlir::OpBuilder &builder,
    mlir::Location loc, mlir::Value baseAddr, mlir::Value varPtrPtr,
    llvm::StringRef name, llvm::ArrayRef<mlir::Value> bounds,
    llvm::ArrayRef<mlir::Value> members, mlir::ArrayAttr membersIndex,
    mlir::omp::ClauseMapFlags mapType,
    mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
    bool partialMap, mlir::FlatSymbolRefAttr mapperId) {
  // Boxes are mapped through the address they wrap.
  if (llvm::isa<fir::BaseBoxType>(baseAddr.getType())) {
    baseAddr = fir::BoxAddrOp::create(builder, loc, baseAddr);
    retTy = baseAddr.getType();
  }

  mlir::Type mappedTy =
      llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType();

  // For types with unknown extents such as <2x?xi32> the incomplete type info
  // is discarded and only the base type is kept. The correct dimensions are
  // later recovered through the bounds info.
  auto seqTy = llvm::dyn_cast<fir::SequenceType>(mappedTy);
  if (seqTy && seqTy.hasDynamicExtents())
    mappedTy = seqTy.getEleTy();

  auto mapFlagsAttr = builder.getAttr<mlir::omp::ClauseMapFlagsAttr>(mapType);
  auto captureKindAttr =
      builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType);

  return mlir::omp::MapInfoOp::create(builder, loc, retTy, baseAddr,
      mlir::TypeAttr::get(mappedTy), mapFlagsAttr, captureKindAttr, varPtrPtr,
      members, membersIndex, bounds, mapperId, builder.getStringAttr(name),
      builder.getBoolAttr(partialMap));
}
|
|
|
|
/// Maps \p val into \p targetOp through a temporary and returns the value to
/// use in its place inside the target region.
///
/// A temporary is created with `firOpBuilder.createTemporary` right after the
/// operation defining \p val (or right before \p targetOp when \p val is a
/// block argument) and \p val is stored into it. A map info op for that
/// temporary is then created right before \p targetOp, appended to the
/// target's map variables, and a matching entry block argument is inserted
/// after the existing map block arguments. Finally, a `fir::LoadOp` of the new
/// block argument is created at the start of the region's entry block.
///
/// The builder's insertion point is restored on return.
///
/// \param firOpBuilder Builder used to create all operations.
/// \param targetOp     Target operation receiving the new map entry.
/// \param val          Value defined outside the target region.
/// \param name         Name recorded on the generated map info op.
/// \return The result of the load created inside the target region.
mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder,
    mlir::omp::TargetOp targetOp, mlir::Value val, llvm::StringRef name) {
  mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
  mlir::Operation *valOp = val.getDefiningOp();

  if (valOp)
    firOpBuilder.setInsertionPointAfter(valOp);
  else
    // This means val is a block argument
    firOpBuilder.setInsertionPoint(targetOp);

  auto copyVal = firOpBuilder.createTemporary(val.getLoc(), val.getType());
  firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);

  fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
      firOpBuilder, val, /*isOptional=*/false, val.getLoc());
  llvm::SmallVector<mlir::Value> bounds =
      fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
          mlir::omp::MapBoundsType>(firOpBuilder, info,
          hlfir::translateToExtendedValue(
              val.getLoc(), firOpBuilder, hlfir::Entity{val})
              .first,
          /*dataExvIsAssumedSize=*/false, val.getLoc());

  // The map info op itself is emitted immediately before the target op.
  firOpBuilder.setInsertionPoint(targetOp);

  mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::implicit;
  mlir::omp::VariableCaptureKind captureKind =
      mlir::omp::VariableCaptureKind::ByRef;

  mlir::Type eleType = copyVal.getType();
  if (auto refType = mlir::dyn_cast<fir::ReferenceType>(copyVal.getType())) {
    eleType = refType.getElementType();
  }

  // Trivial and character element types are captured by copy; every other
  // type except the builtin C pointer type additionally gets the `to` flag.
  if (fir::isa_trivial(eleType) || fir::isa_char(eleType)) {
    captureKind = mlir::omp::VariableCaptureKind::ByCopy;
  } else if (!fir::isa_builtin_cptr_type(eleType)) {
    mapFlag |= mlir::omp::ClauseMapFlags::to;
  }

  mlir::Value mapOp = createMapInfoOp(firOpBuilder, copyVal.getLoc(), copyVal,
      /*varPtrPtr=*/mlir::Value{}, name, bounds,
      /*members=*/llvm::SmallVector<mlir::Value>{},
      /*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind,
      copyVal.getType());

  auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
  mlir::Region &region = targetOp.getRegion();

  // Get the index of the first non-map argument before modifying mapVars,
  // then append an element to mapVars and an associated entry block
  // argument at that index.
  unsigned insertIndex =
      argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs();
  targetOp.getMapVarsMutable().append(mapOp);
  mlir::Value clonedValArg =
      region.insertArgument(insertIndex, copyVal.getType(), copyVal.getLoc());

  mlir::Block *entryBlock = &region.getBlocks().front();
  firOpBuilder.setInsertionPointToStart(entryBlock);
  auto loadOp =
      fir::LoadOp::create(firOpBuilder, clonedValArg.getLoc(), clonedValArg);
  return loadOp.getResult();
}
|
|
|
|
/// Removes the target region's dependencies on values defined outside of it.
///
/// Every value used in the region of \p targetOp but defined above it is
/// handled in one of two ways:
///  - if it is produced by a memory-effect-free operation other than
///    `fir::BoxDimsOp`, that operation is cloned to the front of the region's
///    entry block;
///  - otherwise the value is mapped into the region via `mapTemporaryValue`.
/// In both cases only uses owned by operations located directly in the entry
/// block are redirected. The set of outside values is recomputed after each
/// pass (a clone may itself use outside values) until it is empty.
///
/// \param firOpBuilder Builder forwarded to `mapTemporaryValue`.
/// \param targetOp     Target operation whose region is processed.
void cloneOrMapRegionOutsiders(
    fir::FirOpBuilder &firOpBuilder, mlir::omp::TargetOp targetOp) {
  mlir::Region &region = targetOp.getRegion();
  mlir::Block *entryBlock = &region.getBlocks().front();

  // Selects the uses whose owning operation sits directly in the entry block.
  // NOTE(review): uses in blocks other than the entry block are not
  // redirected by this filter - confirm callers never have such uses.
  auto isUseInEntryBlock = [entryBlock](mlir::OpOperand &use) {
    return use.getOwner()->getBlock() == entryBlock;
  };

  llvm::SetVector<mlir::Value> valuesDefinedAbove;
  mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
  while (!valuesDefinedAbove.empty()) {
    for (mlir::Value val : valuesDefinedAbove) {
      mlir::Operation *valOp = val.getDefiningOp();

      // NOTE: We skip BoxDimsOp's as the lesser of two evils is to map the
      // indices separately, as the alternative is to eventually map the Box,
      // which comes with a fairly large overhead comparatively. We could be
      // more robust about this and check using a BackwardsSlice to see if we
      // run the risk of mapping a box.
      if (valOp && mlir::isMemoryEffectFree(valOp) &&
          !mlir::isa<fir::BoxDimsOp>(valOp)) {
        mlir::Operation *clonedOp = valOp->clone();
        entryBlock->push_front(clonedOp);

        valOp->getResults().replaceUsesWithIf(
            clonedOp->getResults(), isUseInEntryBlock);
        // NOTE(review): presumably equivalent to the call above; kept to
        // preserve the existing behavior.
        valOp->replaceUsesWithIf(clonedOp, isUseInEntryBlock);
      } else {
        mlir::Value mappedTemp = mapTemporaryValue(firOpBuilder, targetOp, val,
            /*name=*/{});
        val.replaceUsesWithIf(mappedTemp, isUseInEntryBlock);
      }
    }
    valuesDefinedAbove.clear();
    mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
  }
}
|
|
|
|
/// Gets or generates a default declare mapper for a given record type.
|
|
///
|
|
/// \param firOpBuilder The builder to use for generating the mapper.
|
|
/// \param loc The location to use for the generated operations.
|
|
/// \param recordType The record type to generate the mapper for.
|
|
/// \param mapperNameStr The name of the mapper to generate.
|
|
/// \param mangler A function to mangle the mapper name for nested types.
|
|
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
|
|
fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
|
|
fir::RecordType recordType, llvm::StringRef mapperNameStr,
|
|
RecordMemberMapperMangler mangler) {
|
|
if (mapperNameStr.empty())
|
|
return {};
|
|
|
|
mlir::ModuleOp moduleOp = firOpBuilder.getModule();
|
|
if (moduleOp.lookupSymbol(mapperNameStr))
|
|
return mlir::FlatSymbolRefAttr::get(
|
|
firOpBuilder.getContext(), mapperNameStr);
|
|
|
|
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
|
|
|
|
firOpBuilder.setInsertionPointToStart(moduleOp.getBody());
|
|
auto declMapperOp = mlir::omp::DeclareMapperOp::create(
|
|
firOpBuilder, loc, mapperNameStr, recordType);
|
|
auto ®ion = declMapperOp.getRegion();
|
|
firOpBuilder.createBlock(®ion);
|
|
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
|
|
|
|
auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
|
|
/*uniq_name=*/"");
|
|
|
|
const auto genBoundsOps = [&](mlir::Value mapVal,
|
|
llvm::SmallVectorImpl<mlir::Value> &bounds) {
|
|
fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(),
|
|
firOpBuilder, hlfir::Entity{mapVal},
|
|
/*contiguousHint=*/true)
|
|
.first;
|
|
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
|
|
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
|
|
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
|
|
mlir::omp::MapBoundsType>(firOpBuilder, info, extVal,
|
|
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
|
|
};
|
|
|
|
const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
|
|
mlir::Type fieldTy, mlir::Type recType) {
|
|
mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc,
|
|
fir::FieldType::get(recType.getContext()), fieldName, recType,
|
|
fir::getTypeParams(rec));
|
|
return fir::CoordinateOp::create(
|
|
firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
|
|
};
|
|
|
|
llvm::SmallVector<mlir::Value> clauseMapVars;
|
|
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
|
|
llvm::SmallVector<mlir::Value> memberMapOps;
|
|
|
|
mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
|
|
mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit;
|
|
mlir::omp::VariableCaptureKind captureKind =
|
|
mlir::omp::VariableCaptureKind::ByRef;
|
|
|
|
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
|
|
const auto &memberName = entry.value().first;
|
|
const auto &memberType = entry.value().second;
|
|
mlir::FlatSymbolRefAttr mapperId;
|
|
if (auto recType = mlir::dyn_cast<fir::RecordType>(
|
|
fir::getFortranElementType(memberType))) {
|
|
std::string mapperIdName =
|
|
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
|
|
mangler(mapperIdName, memberName);
|
|
mapperId = getOrGenImplicitDefaultDeclareMapper(
|
|
firOpBuilder, loc, recType, mapperIdName, mangler);
|
|
}
|
|
|
|
auto ref =
|
|
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
|
|
llvm::SmallVector<mlir::Value> bounds;
|
|
genBoundsOps(ref, bounds);
|
|
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder,
|
|
loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds,
|
|
/*members=*/{},
|
|
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
|
|
/*partialMap=*/false, mapperId);
|
|
memberMapOps.emplace_back(mapOp);
|
|
memberPlacementIndices.emplace_back(
|
|
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
|
|
}
|
|
|
|
llvm::SmallVector<mlir::Value> bounds;
|
|
genBoundsOps(declareOp.getOriginalBase(), bounds);
|
|
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
|
|
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
|
|
firOpBuilder, loc, declareOp.getOriginalBase(),
|
|
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
|
|
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
|
|
captureKind, declareOp.getType(0),
|
|
/*partialMap=*/true);
|
|
|
|
clauseMapVars.emplace_back(mapOp);
|
|
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
|
|
return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr);
|
|
}
|
|
} // namespace Fortran::utils::openmp
|