[flang][OpenMP][DoConcurrent] Emit declare mapper for records (#179936)

Extends `do concurrent` device support by emitting compiler-generated
declare mapper ops for live-ins whose types are record types and have
allocatable members.
This commit is contained in:
Kareem Ergawy 2026-03-11 13:43:55 +01:00 committed by GitHub
parent 3b8cd6c528
commit acd52a2419
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 202 additions and 115 deletions

View File

@ -13,6 +13,7 @@
namespace fir {
class FirOpBuilder;
class RecordType;
} // namespace fir
namespace Fortran::utils::openmp {
@ -59,6 +60,14 @@ mlir::Value mapTemporaryValue(fir::FirOpBuilder &firOpBuilder,
/// maps.
void cloneOrMapRegionOutsiders(
fir::FirOpBuilder &firOpBuilder, mlir::omp::TargetOp targetOp);
using RecordMemberMapperMangler =
std::function<void(std::string &mapperId, llvm::StringRef memberName)>;
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
fir::RecordType recordType, llvm::StringRef mapperNameStr,
RecordMemberMapperMangler mangler = {});
} // namespace Fortran::utils::openmp
#endif // FORTRAN_UTILS_OPENMP_H_

View File

@ -1586,8 +1586,11 @@ void ClauseProcessor::processMapObjects(
if (!recordType)
return mlir::FlatSymbolRefAttr();
return getOrGenImplicitDefaultDeclareMapper(converter, clauseLocation,
recordType, mapperIdName);
return utils::openmp::getOrGenImplicitDefaultDeclareMapper(
converter.getFirOpBuilder(), clauseLocation, recordType, mapperIdName,
[&](std::string &mapperIdName, llvm::StringRef memberName) {
defaultMangler(converter, mapperIdName, memberName);
});
};
auto getDefaultMapperID =

View File

@ -2857,7 +2857,12 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (auto recordType = mlir::dyn_cast_or_null<fir::RecordType>(
converter.genType(*typeSpec)))
mapperId = getOrGenImplicitDefaultDeclareMapper(
converter, loc, recordType, mapperIdName);
converter.getFirOpBuilder(), loc, recordType,
mapperIdName,
[&](std::string &mapperIdName,
llvm::StringRef memberName) {
defaultMangler(converter, mapperIdName, memberName);
});
} else {
mapperId = mlir::FlatSymbolRefAttr::get(
&converter.getMLIRContext(), mapperIdName);

View File

@ -67,113 +67,6 @@ llvm::cl::opt<bool> treatIndexAsSection(
namespace Fortran {
namespace lower {
namespace omp {
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
lower::AbstractConverter &converter, mlir::Location loc,
fir::RecordType recordType, llvm::StringRef mapperNameStr) {
if (mapperNameStr.empty())
return {};
if (converter.getModuleOp().lookupSymbol(mapperNameStr))
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperNameStr);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
firOpBuilder.setInsertionPointToStart(converter.getModuleOp().getBody());
auto declMapperOp = mlir::omp::DeclareMapperOp::create(
firOpBuilder, loc, mapperNameStr, recordType);
auto &region = declMapperOp.getRegion();
firOpBuilder.createBlock(&region);
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
/*uniq_name=*/"");
const auto genBoundsOps = [&](mlir::Value mapVal,
llvm::SmallVectorImpl<mlir::Value> &bounds) {
fir::ExtendedValue extVal =
hlfir::translateToExtendedValue(mapVal.getLoc(), firOpBuilder,
hlfir::Entity{mapVal},
/*contiguousHint=*/true)
.first;
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
firOpBuilder, info, extVal,
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
};
const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
mlir::Type fieldTy, mlir::Type recType) {
mlir::Value field = fir::FieldIndexOp::create(
firOpBuilder, loc, fir::FieldType::get(recType.getContext()), fieldName,
recType, fir::getTypeParams(rec));
return fir::CoordinateOp::create(
firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
};
llvm::SmallVector<mlir::Value> clauseMapVars;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
llvm::SmallVector<mlir::Value> memberMapOps;
mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
mlir::omp::ClauseMapFlags::from |
mlir::omp::ClauseMapFlags::implicit;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
const auto &memberName = entry.value().first;
const auto &memberType = entry.value().second;
mlir::FlatSymbolRefAttr mapperId;
if (auto recType = mlir::dyn_cast<fir::RecordType>(
fir::getFortranElementType(memberType))) {
std::string mapperIdName =
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
else if (auto *memberSym =
converter.getCurrentScope().FindSymbol(memberName))
mapperIdName = converter.mangleName(mapperIdName, memberSym->owner());
mapperId = getOrGenImplicitDefaultDeclareMapper(converter, loc, recType,
mapperIdName);
}
auto ref =
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(ref, bounds);
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(
firOpBuilder, loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"",
bounds,
/*members=*/{},
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
/*partialMap=*/false, mapperId);
memberMapOps.emplace_back(mapOp);
memberPlacementIndices.emplace_back(
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
}
llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(declareOp.getOriginalBase(), bounds);
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
firOpBuilder, loc, declareOp.getOriginalBase(),
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
captureKind, declareOp.getType(0),
/*partialMap=*/true);
clauseMapVars.emplace_back(mapOp);
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
return mlir::FlatSymbolRefAttr::get(&converter.getMLIRContext(),
mapperNameStr);
}
bool requiresImplicitDefaultDeclareMapper(
const semantics::DerivedTypeSpec &typeSpec) {
// ISO C interoperable types (e.g., c_ptr, c_funptr) must always have implicit
@ -1102,6 +995,15 @@ bool hasIteratorIVReference(
return false;
}
void defaultMangler(Fortran::lower::AbstractConverter &converter,
std::string &mapperIdName, llvm::StringRef memberName) {
if (auto *sym = converter.getCurrentScope().FindSymbol(mapperIdName))
mapperIdName = converter.mangleName(mapperIdName, sym->owner());
else if (auto *memberSym =
converter.getCurrentScope().FindSymbol(memberName.str()))
mapperIdName = converter.mangleName(mapperIdName, memberSym->owner());
}
// Build the array coordinate for an object that uses iterator variables.
// If the object is a section, use the first element of that section
// as the coordinate. Currently only support top-level ArrayRef designators.

View File

@ -139,10 +139,6 @@ mlir::Value createParentSymAndGenIntermediateMaps(
OmpMapParentAndMemberData &parentMemberIndices, llvm::StringRef asFortran,
mlir::omp::ClauseMapFlags mapTypeBits);
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
fir::RecordType recordType, llvm::StringRef mapperNameStr);
bool requiresImplicitDefaultDeclareMapper(
const semantics::DerivedTypeSpec &typeSpec);
@ -216,6 +212,14 @@ bool hasIteratorIVReference(
const omp::Object &object,
const llvm::SmallPtrSetImpl<const Fortran::semantics::Symbol *> &ivSyms);
/// Default name mangler for implicit default mappers.
///
/// \param converter The converter to use for name mangling.
/// \param mapperIdName The name of the mapper to mangle.
/// \param memberName The name of the member to mangle.
void defaultMangler(Fortran::lower::AbstractConverter &converter,
std::string &mapperIdName, llvm::StringRef memberName);
mlir::Value genIteratorCoordinate(Fortran::lower::AbstractConverter &converter,
hlfir::Entity entity,
llvm::ArrayRef<mlir::Value> ivs,

View File

@ -22,6 +22,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
namespace flangomp {
#define GEN_PASS_DEF_DOCONCURRENTCONVERSIONPASS
@ -583,12 +584,43 @@ private:
llvm::SmallVector<mlir::Value> boundsOps;
genBoundsOps(builder, liveIn, rawAddr, boundsOps);
auto asRecordType = [&](mlir::Type eleType) {
return mlir::dyn_cast<fir::RecordType>(
fir::getDerivedType(fir::unwrapRefType(eleType)));
};
fir::RecordType recordType = asRecordType(eleType);
bool requiresImplcitMapper = [&]() {
if (!recordType)
return false;
for (auto [fieldName, fieldType] : recordType.getTypeList()) {
if (fir::isAllocatableType(fieldType))
return true;
if (asRecordType(fieldType))
TODO(liveIn.getLoc(), "Nested record types are not supported yet.");
}
return false;
}();
mlir::FlatSymbolRefAttr mapperId;
if (requiresImplcitMapper) {
std::string mapperIdName =
recordType.getName().str() + llvm::omp::OmpDefaultMapperName;
// TODO Add a mangler callback once nested record types are supported.
mapperId = Fortran::utils::openmp::getOrGenImplicitDefaultDeclareMapper(
builder, liveIn.getLoc(), recordType, mapperIdName);
}
return Fortran::utils::openmp::createMapInfoOp(
builder, liveIn.getLoc(), rawAddr,
/*varPtrPtr=*/{}, name.str(), boundsOps,
/*members=*/{},
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind,
rawAddr.getType());
rawAddr.getType(), /*partialMap=*/false, mapperId);
}
mlir::omp::TargetOp

View File

@ -155,4 +155,108 @@ void cloneOrMapRegionOutsiders(
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
}
}
/// Gets or generates a default declare mapper for a given record type.
///
/// \param firOpBuilder The builder to use for generating the mapper.
/// \param loc The location to use for the generated operations.
/// \param recordType The record type to generate the mapper for.
/// \param mapperNameStr The name of the mapper to generate.
/// \param mangler A function to mangle the mapper name for nested types.
mlir::FlatSymbolRefAttr getOrGenImplicitDefaultDeclareMapper(
fir::FirOpBuilder &firOpBuilder, mlir::Location loc,
fir::RecordType recordType, llvm::StringRef mapperNameStr,
RecordMemberMapperMangler mangler) {
if (mapperNameStr.empty())
return {};
mlir::ModuleOp moduleOp = firOpBuilder.getModule();
if (moduleOp.lookupSymbol(mapperNameStr))
return mlir::FlatSymbolRefAttr::get(
firOpBuilder.getContext(), mapperNameStr);
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
firOpBuilder.setInsertionPointToStart(moduleOp.getBody());
auto declMapperOp = mlir::omp::DeclareMapperOp::create(
firOpBuilder, loc, mapperNameStr, recordType);
auto &region = declMapperOp.getRegion();
firOpBuilder.createBlock(&region);
auto mapperArg = region.addArgument(firOpBuilder.getRefType(recordType), loc);
auto declareOp = hlfir::DeclareOp::create(firOpBuilder, loc, mapperArg,
/*uniq_name=*/"");
const auto genBoundsOps = [&](mlir::Value mapVal,
llvm::SmallVectorImpl<mlir::Value> &bounds) {
fir::ExtendedValue extVal = hlfir::translateToExtendedValue(mapVal.getLoc(),
firOpBuilder, hlfir::Entity{mapVal},
/*contiguousHint=*/true)
.first;
fir::factory::AddrAndBoundsInfo info = fir::factory::getDataOperandBaseAddr(
firOpBuilder, mapVal, /*isOptional=*/false, mapVal.getLoc());
bounds = fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(firOpBuilder, info, extVal,
/*dataExvIsAssumedSize=*/false, mapVal.getLoc());
};
const auto getFieldRef = [&](mlir::Value rec, llvm::StringRef fieldName,
mlir::Type fieldTy, mlir::Type recType) {
mlir::Value field = fir::FieldIndexOp::create(firOpBuilder, loc,
fir::FieldType::get(recType.getContext()), fieldName, recType,
fir::getTypeParams(rec));
return fir::CoordinateOp::create(
firOpBuilder, loc, firOpBuilder.getRefType(fieldTy), rec, field);
};
llvm::SmallVector<mlir::Value> clauseMapVars;
llvm::SmallVector<llvm::SmallVector<int64_t>> memberPlacementIndices;
llvm::SmallVector<mlir::Value> memberMapOps;
mlir::omp::ClauseMapFlags mapFlag = mlir::omp::ClauseMapFlags::to |
mlir::omp::ClauseMapFlags::from | mlir::omp::ClauseMapFlags::implicit;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
for (const auto &entry : llvm::enumerate(recordType.getTypeList())) {
const auto &memberName = entry.value().first;
const auto &memberType = entry.value().second;
mlir::FlatSymbolRefAttr mapperId;
if (auto recType = mlir::dyn_cast<fir::RecordType>(
fir::getFortranElementType(memberType))) {
std::string mapperIdName =
recType.getName().str() + llvm::omp::OmpDefaultMapperName;
mangler(mapperIdName, memberName);
mapperId = getOrGenImplicitDefaultDeclareMapper(
firOpBuilder, loc, recType, mapperIdName, mangler);
}
auto ref =
getFieldRef(declareOp.getBase(), memberName, memberType, recordType);
llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(ref, bounds);
mlir::Value mapOp = Fortran::utils::openmp::createMapInfoOp(firOpBuilder,
loc, ref, /*varPtrPtr=*/mlir::Value{}, /*name=*/"", bounds,
/*members=*/{},
/*membersIndex=*/mlir::ArrayAttr{}, mapFlag, captureKind, ref.getType(),
/*partialMap=*/false, mapperId);
memberMapOps.emplace_back(mapOp);
memberPlacementIndices.emplace_back(
llvm::SmallVector<int64_t>{(int64_t)entry.index()});
}
llvm::SmallVector<mlir::Value> bounds;
genBoundsOps(declareOp.getOriginalBase(), bounds);
mlir::omp::ClauseMapFlags parentMapFlag = mlir::omp::ClauseMapFlags::implicit;
mlir::omp::MapInfoOp mapOp = Fortran::utils::openmp::createMapInfoOp(
firOpBuilder, loc, declareOp.getOriginalBase(),
/*varPtrPtr=*/mlir::Value(), /*name=*/"", bounds, memberMapOps,
firOpBuilder.create2DI64ArrayAttr(memberPlacementIndices), parentMapFlag,
captureKind, declareOp.getType(0),
/*partialMap=*/true);
clauseMapVars.emplace_back(mapOp);
mlir::omp::DeclareMapperInfoOp::create(firOpBuilder, loc, clauseMapVars);
return mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(), mapperNameStr);
}
} // namespace Fortran::utils::openmp

View File

@ -0,0 +1,28 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=device %s -o - \
! RUN: | FileCheck %s
module record_with_alloc_mod
implicit none
public :: record_with_alloc
type record_with_alloc
real, allocatable :: values_(:)
end type
end module record_with_alloc_mod
subroutine random_inputs()
use record_with_alloc_mod, only : record_with_alloc
implicit none
type(record_with_alloc) :: inputs(2)
integer :: i
do concurrent(i=1:10)
inputs(1)%values_ = [1,2,3,4]
end do
end subroutine
! CHECK: omp.declare_mapper @[[MAPPER_NAME:.*record_with_alloc_omp_default_mapper]] : !fir.type<{{.*}}record_with_alloc{{.*}}>
! CHECK: func.func @{{.*}}random_inputs()
! CHECK: %[[ARR_DECL:.*]]:2 = hlfir.declare {{.*}} {uniq_name = "{{.*}}inputs"}
! CHECK: omp.map.info var_ptr(%[[ARR_DECL]]#1 : {{.*}}) {{.*}} mapper(@[[MAPPER_NAME]])