[flang][OpenMP][DoConcurrent] Emit declare mapper for records (#179936)
Extends `do concurrent` device support by emitting compiler-generated declare mapper ops for live-ins whose types are record types and have allocatable members.
This commit is contained in:
parent
3b8cd6c528
commit
acd52a2419
@ -13,6 +13,7 @@
|
||||
|
||||
namespace fir {
|
||||
class FirOpBuilder;
|
||||
class RecordType;
|
||||
} // namespace fir
|
||||
|
||||
namespace Fortran::utils::openmp {
|
||||
@ -59,6 +60,14 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder,
|
||||
/// maps.
|
||||
void cloneOrMapRegionOutsiders(
|
||||
fir::FirOpBuilder &firOpBuilder, mlir::omp::TargetOp targetOp);
|
||||
|
||||
using RecordMemberMapperMangler =
|
||||
std::function<void(std::string &mapperId, llvm::StringRef memberName)>;
|
||||
|
||||
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
|
||||
fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
|
||||
fir::RecordType recordType, llvm::StringRef mapperNameStr,
|
||||
RecordMemberMapperMangler mangler = {});
|
||||
} // namespace Fortran::utils::openmp
|
||||
|
||||
#endif // FORTRAN_UTILS_OPENMP_H_
|
||||
|
||||
@ -1586,8 +1586,11 @@ void ClauseProcessor::processMapObjects(
|
||||
if (!recordType)
|
||||
return mlir::FlatSymbolRefAttr();
|
||||
|
||||
return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation,
|
||||
recordType, mapperIdName);
|
||||
return utils::openmp::getOrGenImplicitDefaultDeclareMapper(
|
||||
converter.getFirOpBuilder(), clauseLocation, recordType, mapperIdName,
|
||||
[&](std::string &mapperIdName, llvm::StringRef memberName) {
|
||||
defaultMangler(converter, mapperIdName, memberName);
|
||||
});
|
||||
};
|
||||
|
||||
auto getDefaultMapperID =
|
||||
|
||||
@ -2857,7 +2857,12 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
|
||||
if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>(
|
||||
converter.genType(*typeSpec)))
|
||||
mapperId = getOrGenImplicitDefaultDeclareMapper(
|
||||
converter, loc, recordType, mapperIdName);
|
||||
converter.getFirOpBuilder(), loc, recordType,
|
||||
mapperIdName,
|
||||
[&](std::string &mapperIdName,
|
||||
llvm::StringRef memberName) {
|
||||
defaultMangler(converter, mapperIdName, memberName);
|
||||
});
|
||||
} else {
|
||||
mapperId = mlir::FlatSymbolRefAttr::get(
|
||||
&converter.getMLIRContext(), mapperIdName);
|
||||
|
||||
@ -67,113 +67,6 @@ llvm::cl::opt<bool> treatIndexAsSection(
|
||||
namespace Fortran {
|
||||
namespace lower {
|
||||
namespace omp {
|
||||
|
||||
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
|
||||
lower::AbstractConverter &converter, mlir::Location loc,
|
||||
fir::RecordType recordType, llvm::StringRef mapperNameStr) {
|
||||
if (mapperNameStr.empty())
|
||||
return {};
|
||||
|
||||
if (converter.getModuleOp().lookupSymbol(mapperNameStr))
|
||||
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
|
||||
mapperNameStr);
|
||||
|
||||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
|
||||
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
|
||||
|
||||
firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
|
||||
auto declMapperOp = mlir::omp::DeclareMapperOp::create(
|
||||
firOpBuilder, loc, mapperNameStr, recordType);
|
||||
auto ®ion = declMapperOp.getRegion();
|
||||
firOpBuilder.createBlock(®ion);
|
||||
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
|
||||
|
||||
auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
|
||||
/*uniq_name=*/"");
|
||||
|
||||
const auto genBoundsOps = [&](mlir::Value mapVal,
|
||||
llvm::SmallVectorImpl<mlir::Value> &bounds) {
|
||||
fir::ExtendedValue extVal =
|
||||
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
|
||||
hlfir::Entity{mapVal},
|
||||
/*contiguousHint=*/true)
|
||||
.first;
|
||||
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
|
||||
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
|
||||
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
|
||||
mlir::omp::MapBoundsType>(
|
||||
firOpBuilder, info, extVal,
|
||||
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
|
||||
};
|
||||
|
||||
const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
|
||||
mlir::Type fieldTy, mlir::Type recType) {
|
||||
mlir::Value field = fir::FieldIndexOp::create(
|
||||
firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName,
|
||||
recType, fir::getTypeParams(rec));
|
||||
return fir::CoordinateOp::create(
|
||||
firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
|
||||
};
|
||||
|
||||
llvm::SmallVector<mlir::Value> clauseMapVars;
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
|
||||
llvm::SmallVector<mlir::Value> memberMapOps;
|
||||
|
||||
mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
|
||||
mlir::omp::ClauseMapFlags::from |
|
||||
mlir::omp::ClauseMapFlags::implicit;
|
||||
mlir::omp::VariableCaptureKind captureKind =
|
||||
mlir::omp::VariableCaptureKind::ByRef;
|
||||
|
||||
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
|
||||
const auto &memberName = entry.value().first;
|
||||
const auto &memberType = entry.value().second;
|
||||
mlir::FlatSymbolRefAttr mapperId;
|
||||
if (auto recType = mlir::dyn_cast<fir::RecordType>(
|
||||
fir::getFortranElementType(memberType))) {
|
||||
std::string mapperIdName =
|
||||
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
|
||||
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
|
||||
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
|
||||
else if (auto *memberSym =
|
||||
converter.getCurrentScope().FindSymbol(memberName))
|
||||
mapperIdName = converter.mangleName(mapperIdName, memberSym->owner());
|
||||
|
||||
mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType,
|
||||
mapperIdName);
|
||||
}
|
||||
|
||||
auto ref =
|
||||
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
|
||||
llvm::SmallVector<mlir::Value> bounds;
|
||||
genBoundsOps(ref, bounds);
|
||||
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(
|
||||
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"",
|
||||
bounds,
|
||||
/*members=*/{},
|
||||
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
|
||||
/*partialMap=*/false, mapperId);
|
||||
memberMapOps.emplace_back(mapOp);
|
||||
memberPlacementIndices.emplace_back(
|
||||
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value> bounds;
|
||||
genBoundsOps(declareOp.getOriginalBase(), bounds);
|
||||
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
|
||||
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
|
||||
firOpBuilder, loc, declareOp.getOriginalBase(),
|
||||
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
|
||||
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
|
||||
captureKind, declareOp.getType(0),
|
||||
/*partialMap=*/true);
|
||||
|
||||
clauseMapVars.emplace_back(mapOp);
|
||||
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
|
||||
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
|
||||
mapperNameStr);
|
||||
}
|
||||
|
||||
bool requiresImplicitDefaultDeclareMapper(
|
||||
const semantics::DerivedTypeSpec &typeSpec) {
|
||||
// ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit
|
||||
@ -1102,6 +995,15 @@ bool hasIteratorIVReference(
|
||||
return false;
|
||||
}
|
||||
|
||||
void defaultMangler(Fortran::lower::AbstractConverter &converter,
|
||||
std::string &mapperIdName, llvm::StringRef memberName) {
|
||||
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
|
||||
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
|
||||
else if (auto *memberSym =
|
||||
converter.getCurrentScope().FindSymbol(memberName.str()))
|
||||
mapperIdName = converter.mangleName(mapperIdName, memberSym->owner());
|
||||
}
|
||||
|
||||
// Build the array coordinate for an object that uses iterator variables.
|
||||
// If the object is a section, use the first element of that section
|
||||
// as the coordinate. Currently only support top-level ArrayRef designators.
|
||||
|
||||
@ -139,10 +139,6 @@ mlir::Value createParentSymAndGenIntermediateMaps(
|
||||
OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran,
|
||||
mlir::omp::ClauseMapFlags mapTypeBits);
|
||||
|
||||
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
|
||||
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
|
||||
fir::RecordType recordType, llvm::StringRef mapperNameStr);
|
||||
|
||||
bool requiresImplicitDefaultDeclareMapper(
|
||||
const semantics::DerivedTypeSpec &typeSpec);
|
||||
|
||||
@ -216,6 +212,14 @@ bool hasIteratorIVReference(
|
||||
const omp::Object &object,
|
||||
const llvm::SmallPtrSetImpl<const Fortran::semantics::Symbol *> &ivSyms);
|
||||
|
||||
/// Default name mangler for implicit default mappers.
|
||||
///
|
||||
/// \param converter The converter to use for name mangling.
|
||||
/// \param mapperIdName The name of the mapper to mangle.
|
||||
/// \param memberName The name of the member to mangle.
|
||||
void defaultMangler(Fortran::lower::AbstractConverter &converter,
|
||||
std::string &mapperIdName, llvm::StringRef memberName);
|
||||
|
||||
mlir::Value genIteratorCoordinate(Fortran::lower::AbstractConverter &converter,
|
||||
hlfir::Entity entity,
|
||||
llvm::ArrayRef<mlir::Value> ivs,
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Frontend/OpenMP/OMPConstants.h"
|
||||
|
||||
namespace flangomp {
|
||||
#define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS
|
||||
@ -583,12 +584,43 @@ private:
|
||||
llvm::SmallVector<mlir::Value> boundsOps;
|
||||
genBoundsOps(builder, liveIn, rawAddr, boundsOps);
|
||||
|
||||
auto asRecordType = [&](mlir::Type eleType) {
|
||||
return mlir::dyn_cast<fir::RecordType>(
|
||||
fir::getDerivedType(fir::unwrapRefType(eleType)));
|
||||
};
|
||||
|
||||
fir::RecordType recordType = asRecordType(eleType);
|
||||
|
||||
bool requiresImplcitMapper = [&]() {
|
||||
if (!recordType)
|
||||
return false;
|
||||
|
||||
for (auto [fieldName, fieldType] : recordType.getTypeList()) {
|
||||
if (fir::isAllocatableType(fieldType))
|
||||
return true;
|
||||
|
||||
if (asRecordType(fieldType))
|
||||
TODO(liveIn.getLoc(), "Nested record types are not supported yet.");
|
||||
}
|
||||
|
||||
return false;
|
||||
}();
|
||||
|
||||
mlir::FlatSymbolRefAttr mapperId;
|
||||
if (requiresImplcitMapper) {
|
||||
std::string mapperIdName =
|
||||
recordType.getName().str() + llvm::omp::OmpDefaultMapperName;
|
||||
// TODO Add a mangler callback once nested record types are supported.
|
||||
mapperId = Fortran::utils::openmp::getOrGenImplicitDefaultDeclareMapper(
|
||||
builder, liveIn.getLoc(), recordType, mapperIdName);
|
||||
}
|
||||
|
||||
return Fortran::utils::openmp::createMapInfoOp(
|
||||
builder, liveIn.getLoc(), rawAddr,
|
||||
/*varPtrPtr=*/{}, name.str(), boundsOps,
|
||||
/*members=*/{},
|
||||
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind,
|
||||
rawAddr.getType());
|
||||
rawAddr.getType(), /*partialMap=*/false, mapperId);
|
||||
}
|
||||
|
||||
mlir::omp::TargetOp
|
||||
|
||||
@ -155,4 +155,108 @@ void cloneOrMapRegionOutsiders(
|
||||
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets or generates a default declare mapper for a given record type.
|
||||
///
|
||||
/// \param firOpBuilder The builder to use for generating the mapper.
|
||||
/// \param loc The location to use for the generated operations.
|
||||
/// \param recordType The record type to generate the mapper for.
|
||||
/// \param mapperNameStr The name of the mapper to generate.
|
||||
/// \param mangler A function to mangle the mapper name for nested types.
|
||||
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
|
||||
fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
|
||||
fir::RecordType recordType, llvm::StringRef mapperNameStr,
|
||||
RecordMemberMapperMangler mangler) {
|
||||
if (mapperNameStr.empty())
|
||||
return {};
|
||||
|
||||
mlir::ModuleOp moduleOp = firOpBuilder.getModule();
|
||||
if (moduleOp.lookupSymbol(mapperNameStr))
|
||||
return mlir::FlatSymbolRefAttr::get(
|
||||
firOpBuilder.getContext(), mapperNameStr);
|
||||
|
||||
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
|
||||
|
||||
firOpBuilder.setInsertionPointToStart(moduleOp.getBody());
|
||||
auto declMapperOp = mlir::omp::DeclareMapperOp::create(
|
||||
firOpBuilder, loc, mapperNameStr, recordType);
|
||||
auto ®ion = declMapperOp.getRegion();
|
||||
firOpBuilder.createBlock(®ion);
|
||||
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
|
||||
|
||||
auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
|
||||
/*uniq_name=*/"");
|
||||
|
||||
const auto genBoundsOps = [&](mlir::Value mapVal,
|
||||
llvm::SmallVectorImpl<mlir::Value> &bounds) {
|
||||
fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(),
|
||||
firOpBuilder, hlfir::Entity{mapVal},
|
||||
/*contiguousHint=*/true)
|
||||
.first;
|
||||
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
|
||||
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
|
||||
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
|
||||
mlir::omp::MapBoundsType>(firOpBuilder, info, extVal,
|
||||
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
|
||||
};
|
||||
|
||||
const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
|
||||
mlir::Type fieldTy, mlir::Type recType) {
|
||||
mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc,
|
||||
fir::FieldType::get(recType.getContext()), fieldName, recType,
|
||||
fir::getTypeParams(rec));
|
||||
return fir::CoordinateOp::create(
|
||||
firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
|
||||
};
|
||||
|
||||
llvm::SmallVector<mlir::Value> clauseMapVars;
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
|
||||
llvm::SmallVector<mlir::Value> memberMapOps;
|
||||
|
||||
mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
|
||||
mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit;
|
||||
mlir::omp::VariableCaptureKind captureKind =
|
||||
mlir::omp::VariableCaptureKind::ByRef;
|
||||
|
||||
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
|
||||
const auto &memberName = entry.value().first;
|
||||
const auto &memberType = entry.value().second;
|
||||
mlir::FlatSymbolRefAttr mapperId;
|
||||
if (auto recType = mlir::dyn_cast<fir::RecordType>(
|
||||
fir::getFortranElementType(memberType))) {
|
||||
std::string mapperIdName =
|
||||
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
|
||||
mangler(mapperIdName, memberName);
|
||||
mapperId = getOrGenImplicitDefaultDeclareMapper(
|
||||
firOpBuilder, loc, recType, mapperIdName, mangler);
|
||||
}
|
||||
|
||||
auto ref =
|
||||
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
|
||||
llvm::SmallVector<mlir::Value> bounds;
|
||||
genBoundsOps(ref, bounds);
|
||||
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder,
|
||||
loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds,
|
||||
/*members=*/{},
|
||||
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
|
||||
/*partialMap=*/false, mapperId);
|
||||
memberMapOps.emplace_back(mapOp);
|
||||
memberPlacementIndices.emplace_back(
|
||||
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value> bounds;
|
||||
genBoundsOps(declareOp.getOriginalBase(), bounds);
|
||||
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
|
||||
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
|
||||
firOpBuilder, loc, declareOp.getOriginalBase(),
|
||||
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
|
||||
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
|
||||
captureKind, declareOp.getType(0),
|
||||
/*partialMap=*/true);
|
||||
|
||||
clauseMapVars.emplace_back(mapOp);
|
||||
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
|
||||
return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr);
|
||||
}
|
||||
} // namespace Fortran::utils::openmp
|
||||
|
||||
28
flang/test/Transforms/DoConcurrent/implicit_mapper.f90
Normal file
28
flang/test/Transforms/DoConcurrent/implicit_mapper.f90
Normal file
@ -0,0 +1,28 @@
|
||||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=device %s -o - \
|
||||
! RUN: | FileCheck %s
|
||||
|
||||
module record_with_alloc_mod
|
||||
implicit none
|
||||
public :: record_with_alloc
|
||||
|
||||
type record_with_alloc
|
||||
real, allocatable :: values_(:)
|
||||
end type
|
||||
end module record_with_alloc_mod
|
||||
|
||||
subroutine random_inputs()
|
||||
use record_with_alloc_mod, only : record_with_alloc
|
||||
implicit none
|
||||
type(record_with_alloc) :: inputs(2)
|
||||
integer :: i
|
||||
|
||||
do concurrent(i=1:10)
|
||||
inputs(1)%values_ = [1,2,3,4]
|
||||
end do
|
||||
end subroutine
|
||||
|
||||
! CHECK: omp.declare_mapper @[[MAPPER_NAME:.*record_with_alloc_omp_default_mapper]] : !fir.type<{{.*}}record_with_alloc{{.*}}>
|
||||
|
||||
! CHECK: func.func @{{.*}}random_inputs()
|
||||
! CHECK: %[[ARR_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = "{{.*}}inputs"}
|
||||
! CHECK: omp.map.info var_ptr(%[[ARR_DECL]]#1 : {{.*}}) {{.*}} mapper(@[[MAPPER_NAME]])
|
||||
Loading…
x
Reference in New Issue
Block a user