[Flang][OpenMP] Implicitly map nested allocatable components in derived types (#160116)
This PR adds support for nested derived types and their mappers to the MapInfoFinalization pass. - Generalize MapInfoFinalization to add child maps for arbitrarily nested allocatables when a derived object is mapped via declare mapper. - Traverse HLFIR designates rooted at the target block arg and build full coordinate_of chains; append members with correct membersIndex. This fixes #156461.
This commit is contained in:
parent
d2ac21d328
commit
b4f1e0e5b1
@ -701,40 +701,28 @@ class MapInfoFinalizationPass
|
||||
|
||||
auto recordType = mlir::cast<fir::RecordType>(underlyingType);
|
||||
llvm::SmallVector<mlir::Value> newMapOpsForFields;
|
||||
llvm::SmallVector<int64_t> fieldIndicies;
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndexPaths;
|
||||
|
||||
for (auto fieldMemTyPair : recordType.getTypeList()) {
|
||||
auto &field = fieldMemTyPair.first;
|
||||
auto memTy = fieldMemTyPair.second;
|
||||
|
||||
bool shouldMapField =
|
||||
llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
|
||||
if (!fir::isAllocatableType(memTy))
|
||||
return false;
|
||||
|
||||
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
|
||||
if (!designateOp)
|
||||
return false;
|
||||
|
||||
return designateOp.getComponent() &&
|
||||
designateOp.getComponent()->strref() == field;
|
||||
}) != mapVarForwardSlice.end();
|
||||
|
||||
// TODO Handle recursive record types. Adapting
|
||||
// `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
|
||||
// entities might be helpful here.
|
||||
|
||||
if (!shouldMapField)
|
||||
continue;
|
||||
|
||||
int32_t fieldIdx = recordType.getFieldIndex(field);
|
||||
auto appendMemberMap = [&](mlir::Location loc, mlir::Value coordRef,
|
||||
mlir::Type memTy,
|
||||
llvm::ArrayRef<int64_t> indexPath,
|
||||
llvm::StringRef memberName) {
|
||||
// Check if already mapped (index path equality).
|
||||
bool alreadyMapped = [&]() {
|
||||
if (op.getMembersIndexAttr())
|
||||
for (auto indexList : op.getMembersIndexAttr()) {
|
||||
auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
|
||||
if (indexListAttr.size() == 1 &&
|
||||
mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
|
||||
fieldIdx)
|
||||
if (indexListAttr.size() != indexPath.size())
|
||||
continue;
|
||||
bool allEq = true;
|
||||
for (auto [i, attr] : llvm::enumerate(indexListAttr)) {
|
||||
if (mlir::cast<mlir::IntegerAttr>(attr).getInt() !=
|
||||
indexPath[i]) {
|
||||
allEq = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (allEq)
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -742,42 +730,128 @@ class MapInfoFinalizationPass
|
||||
}();
|
||||
|
||||
if (alreadyMapped)
|
||||
continue;
|
||||
return;
|
||||
|
||||
builder.setInsertionPoint(op);
|
||||
fir::IntOrValue idxConst =
|
||||
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
|
||||
auto fieldCoord = fir::CoordinateOp::create(
|
||||
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
|
||||
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
|
||||
fir::factory::AddrAndBoundsInfo info =
|
||||
fir::factory::getDataOperandBaseAddr(
|
||||
builder, fieldCoord, /*isOptional=*/false, op.getLoc());
|
||||
fir::factory::getDataOperandBaseAddr(builder, coordRef,
|
||||
/*isOptional=*/false, loc);
|
||||
llvm::SmallVector<mlir::Value> bounds =
|
||||
fir::factory::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
|
||||
mlir::omp::MapBoundsType>(
|
||||
builder, info,
|
||||
hlfir::translateToExtendedValue(op.getLoc(), builder,
|
||||
hlfir::Entity{fieldCoord})
|
||||
hlfir::translateToExtendedValue(loc, builder,
|
||||
hlfir::Entity{coordRef})
|
||||
.first,
|
||||
/*dataExvIsAssumedSize=*/false, op.getLoc());
|
||||
/*dataExvIsAssumedSize=*/false, loc);
|
||||
|
||||
mlir::omp::MapInfoOp fieldMapOp = mlir::omp::MapInfoOp::create(
|
||||
builder, op.getLoc(), fieldCoord.getResult().getType(),
|
||||
fieldCoord.getResult(),
|
||||
mlir::TypeAttr::get(
|
||||
fir::unwrapRefType(fieldCoord.getResult().getType())),
|
||||
builder, loc, coordRef.getType(), coordRef,
|
||||
mlir::TypeAttr::get(fir::unwrapRefType(coordRef.getType())),
|
||||
op.getMapTypeAttr(),
|
||||
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
|
||||
mlir::omp::VariableCaptureKind::ByRef),
|
||||
/*varPtrPtr=*/mlir::Value{}, /*members=*/mlir::ValueRange{},
|
||||
/*members_index=*/mlir::ArrayAttr{}, bounds,
|
||||
/*mapperId=*/mlir::FlatSymbolRefAttr(),
|
||||
builder.getStringAttr(op.getNameAttr().strref() + "." + field +
|
||||
".implicit_map"),
|
||||
builder.getStringAttr(op.getNameAttr().strref() + "." +
|
||||
memberName + ".implicit_map"),
|
||||
/*partial_map=*/builder.getBoolAttr(false));
|
||||
newMapOpsForFields.emplace_back(fieldMapOp);
|
||||
fieldIndicies.emplace_back(fieldIdx);
|
||||
newMemberIndexPaths.emplace_back(indexPath.begin(), indexPath.end());
|
||||
};
|
||||
|
||||
// 1) Handle direct top-level allocatable fields (existing behavior).
|
||||
for (auto fieldMemTyPair : recordType.getTypeList()) {
|
||||
auto &field = fieldMemTyPair.first;
|
||||
auto memTy = fieldMemTyPair.second;
|
||||
|
||||
if (!fir::isAllocatableType(memTy))
|
||||
continue;
|
||||
|
||||
bool referenced = llvm::any_of(mapVarForwardSlice, [&](auto *opv) {
|
||||
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(opv);
|
||||
return designateOp && designateOp.getComponent() &&
|
||||
designateOp.getComponent()->strref() == field;
|
||||
});
|
||||
if (!referenced)
|
||||
continue;
|
||||
|
||||
int32_t fieldIdx = recordType.getFieldIndex(field);
|
||||
builder.setInsertionPoint(op);
|
||||
fir::IntOrValue idxConst =
|
||||
mlir::IntegerAttr::get(builder.getI32Type(), fieldIdx);
|
||||
auto fieldCoord = fir::CoordinateOp::create(
|
||||
builder, op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
|
||||
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
|
||||
appendMemberMap(op.getLoc(), fieldCoord, memTy, {fieldIdx}, field);
|
||||
}
|
||||
|
||||
// Handle nested allocatable fields along any component chain
|
||||
// referenced in the region via HLFIR designates.
|
||||
for (mlir::Operation *sliceOp : mapVarForwardSlice) {
|
||||
auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
|
||||
if (!designateOp || !designateOp.getComponent())
|
||||
continue;
|
||||
llvm::SmallVector<llvm::StringRef> compPathReversed;
|
||||
compPathReversed.push_back(designateOp.getComponent()->strref());
|
||||
mlir::Value curBase = designateOp.getMemref();
|
||||
bool rootedAtMapArg = false;
|
||||
while (true) {
|
||||
if (auto parentDes = curBase.getDefiningOp<hlfir::DesignateOp>()) {
|
||||
if (!parentDes.getComponent())
|
||||
break;
|
||||
compPathReversed.push_back(parentDes.getComponent()->strref());
|
||||
curBase = parentDes.getMemref();
|
||||
continue;
|
||||
}
|
||||
if (auto decl = curBase.getDefiningOp<hlfir::DeclareOp>()) {
|
||||
if (auto barg =
|
||||
mlir::dyn_cast<mlir::BlockArgument>(decl.getMemref()))
|
||||
rootedAtMapArg = (barg == opBlockArg);
|
||||
} else if (auto blockArg =
|
||||
mlir::dyn_cast_or_null<mlir::BlockArgument>(
|
||||
curBase)) {
|
||||
rootedAtMapArg = (blockArg == opBlockArg);
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (!rootedAtMapArg || compPathReversed.size() < 2)
|
||||
continue;
|
||||
builder.setInsertionPoint(op);
|
||||
llvm::SmallVector<int64_t> indexPath;
|
||||
mlir::Type curTy = underlyingType;
|
||||
mlir::Value coordRef = op.getVarPtr();
|
||||
bool validPath = true;
|
||||
for (llvm::StringRef compName : llvm::reverse(compPathReversed)) {
|
||||
auto recTy = mlir::dyn_cast<fir::RecordType>(curTy);
|
||||
if (!recTy) {
|
||||
validPath = false;
|
||||
break;
|
||||
}
|
||||
int32_t idx = recTy.getFieldIndex(compName);
|
||||
if (idx < 0) {
|
||||
validPath = false;
|
||||
break;
|
||||
}
|
||||
indexPath.push_back(idx);
|
||||
mlir::Type memTy = recTy.getType(idx);
|
||||
fir::IntOrValue idxConst =
|
||||
mlir::IntegerAttr::get(builder.getI32Type(), idx);
|
||||
coordRef = fir::CoordinateOp::create(
|
||||
builder, op.getLoc(), builder.getRefType(memTy), coordRef,
|
||||
llvm::SmallVector<fir::IntOrValue, 1>{idxConst});
|
||||
curTy = memTy;
|
||||
}
|
||||
if (!validPath)
|
||||
continue;
|
||||
if (auto finalRefTy =
|
||||
mlir::dyn_cast<fir::ReferenceType>(coordRef.getType())) {
|
||||
mlir::Type eleTy = finalRefTy.getElementType();
|
||||
if (fir::isAllocatableType(eleTy))
|
||||
appendMemberMap(op.getLoc(), coordRef, eleTy, indexPath,
|
||||
compPathReversed.front());
|
||||
}
|
||||
}
|
||||
|
||||
if (newMapOpsForFields.empty())
|
||||
@ -785,10 +859,8 @@ class MapInfoFinalizationPass
|
||||
|
||||
op.getMembersMutable().append(newMapOpsForFields);
|
||||
llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
|
||||
mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
|
||||
|
||||
if (oldMembersIdxAttr)
|
||||
for (mlir::Attribute indexList : oldMembersIdxAttr) {
|
||||
if (mlir::ArrayAttr oldAttr = op.getMembersIndexAttr())
|
||||
for (mlir::Attribute indexList : oldAttr) {
|
||||
llvm::SmallVector<int64_t> listVec;
|
||||
|
||||
for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
|
||||
@ -796,10 +868,8 @@ class MapInfoFinalizationPass
|
||||
|
||||
newMemberIndices.emplace_back(std::move(listVec));
|
||||
}
|
||||
|
||||
for (int64_t newFieldIdx : fieldIndicies)
|
||||
newMemberIndices.emplace_back(
|
||||
llvm::SmallVector<int64_t>(1, newFieldIdx));
|
||||
for (auto &path : newMemberIndexPaths)
|
||||
newMemberIndices.emplace_back(path);
|
||||
|
||||
op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
|
||||
op.setPartialMap(true);
|
||||
|
||||
@ -6,6 +6,7 @@
|
||||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-3.f90 -o - | FileCheck %t/omp-declare-mapper-3.f90
|
||||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-4.f90 -o - | FileCheck %t/omp-declare-mapper-4.f90
|
||||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 %t/omp-declare-mapper-5.f90 -o - | FileCheck %t/omp-declare-mapper-5.f90
|
||||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %t/omp-declare-mapper-6.f90 -o - | FileCheck %t/omp-declare-mapper-6.f90
|
||||
|
||||
!--- omp-declare-mapper-1.f90
|
||||
subroutine declare_mapper_1
|
||||
@ -262,3 +263,40 @@ contains
|
||||
!$omp end target
|
||||
end subroutine
|
||||
end program declare_mapper_5
|
||||
|
||||
!--- omp-declare-mapper-6.f90
|
||||
subroutine declare_mapper_nested_parent
|
||||
type :: inner_t
|
||||
real, allocatable :: deep_arr(:)
|
||||
end type inner_t
|
||||
|
||||
type, abstract :: base_t
|
||||
real, allocatable :: base_arr(:)
|
||||
type(inner_t) :: inner
|
||||
end type base_t
|
||||
|
||||
type, extends(base_t) :: real_t
|
||||
real, allocatable :: real_arr(:)
|
||||
end type real_t
|
||||
|
||||
!$omp declare mapper (custommapper : real_t :: t) map(tofrom: t%base_arr, t%real_arr)
|
||||
|
||||
type(real_t) :: r
|
||||
|
||||
allocate(r%base_arr(10))
|
||||
allocate(r%inner%deep_arr(10))
|
||||
allocate(r%real_arr(10))
|
||||
r%base_arr = 1.0
|
||||
r%inner%deep_arr = 4.0
|
||||
r%real_arr = 0.0
|
||||
|
||||
! CHECK: omp.target
|
||||
! Check implicit maps for nested parent and deep nested allocatable payloads
|
||||
! CHECK-DAG: omp.map.info {{.*}} {name = "r.base_arr.implicit_map"}
|
||||
! CHECK-DAG: omp.map.info {{.*}} {name = "r.deep_arr.implicit_map"}
|
||||
! The declared mapper's own allocatable is still mapped implicitly
|
||||
! CHECK-DAG: omp.map.info {{.*}} {name = "r.real_arr.implicit_map"}
|
||||
!$omp target map(mapper(custommapper), tofrom: r)
|
||||
r%real_arr = r%base_arr(1) + r%inner%deep_arr(1)
|
||||
!$omp end target
|
||||
end subroutine declare_mapper_nested_parent
|
||||
|
||||
@ -0,0 +1,43 @@
|
||||
! This test validates that declare mapper for a derived type that extends
|
||||
! a parent type with an allocatable component correctly maps the nested
|
||||
! allocatable payload via the mapper when the whole object is mapped on
|
||||
! target.
|
||||
|
||||
! REQUIRES: flang, amdgpu
|
||||
|
||||
! RUN: %libomptarget-compile-fortran-run-and-check-generic
|
||||
|
||||
program target_declare_mapper_parent_allocatable
|
||||
implicit none
|
||||
|
||||
type, abstract :: base_t
|
||||
real, allocatable :: base_arr(:)
|
||||
end type base_t
|
||||
|
||||
type, extends(base_t) :: real_t
|
||||
real, allocatable :: real_arr(:)
|
||||
end type real_t
|
||||
!$omp declare mapper(custommapper: real_t :: t) map(t%base_arr, t%real_arr)
|
||||
|
||||
type(real_t) :: r
|
||||
integer :: i
|
||||
allocate(r%base_arr(10), source=1.0)
|
||||
allocate(r%real_arr(10), source=1.0)
|
||||
|
||||
!$omp target map(tofrom: r)
|
||||
do i = 1, size(r%base_arr)
|
||||
r%base_arr(i) = 2.0
|
||||
r%real_arr(i) = 3.0
|
||||
r%real_arr(i) = r%base_arr(1)
|
||||
end do
|
||||
!$omp end target
|
||||
|
||||
|
||||
!CHECK: base_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
|
||||
print*, "base_arr: ", r%base_arr
|
||||
!CHECK: real_arr: 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
|
||||
print*, "real_arr: ", r%real_arr
|
||||
|
||||
deallocate(r%real_arr)
|
||||
deallocate(r%base_arr)
|
||||
end program target_declare_mapper_parent_allocatable
|
||||
Loading…
x
Reference in New Issue
Block a user