[flang][cuda] Move cuf.set_allocator_idx after derived-type init (#148936)

Derived type initialization overwrite the component descriptor. Place
the `cuf.set_allocator_idx` after the initialization is performed.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2025-07-15 13:52:00 -07:00 committed by GitHub
parent 42d2ae1034
commit b4e2272271
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 100 additions and 76 deletions

View File

@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
return builder.create<cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
indices);
if (!cuf::isCUDADeviceContext(builder.getRegion())) {
mlir::Value alloc = builder.create<cuf::AllocOp>(
loc, ty, nm, symNm, dataAttr, lenParams, indices);
if (const auto *details{
ultimateSymbol
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
const Fortran::semantics::DeclTypeSpec *type{details->type()};
const Fortran::semantics::DerivedTypeSpec *derived{
type ? type->AsDerived() : nullptr};
if (derived) {
Fortran::semantics::UltimateComponentIterator components{*derived};
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
llvm::SmallVector<mlir::Value> coordinates;
for (const auto &sym : components) {
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
mlir::Type fieldTy;
std::vector<mlir::Value> coordinates;
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
// Field found in the base record type.
auto fieldName = recTy.getTypeList()[fieldIdx].first;
fieldTy = recTy.getTypeList()[fieldIdx].second;
mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(fieldIndex);
} else {
// Field not found in base record type, search in potential
// record type components.
for (auto component : recTy.getTypeList()) {
if (auto childRecTy =
mlir::dyn_cast<fir::RecordType>(component.second)) {
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
mlir::Value parentFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(childRecTy.getContext()),
component.first, recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(parentFieldIndex);
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
mlir::Value childFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()),
fieldName, childRecTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(childFieldIndex);
break;
}
}
}
}
if (coordinates.empty())
TODO(loc, "device resident component in complex derived-type "
"hierarchy");
mlir::Value comp = builder.create<fir::CoordinateOp>(
loc, builder.getRefType(fieldTy), alloc, coordinates);
cuf::DataAttributeAttr dataAttr =
Fortran::lower::translateSymbolCUFDataAttribute(
builder.getContext(), sym);
builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
}
}
}
}
return alloc;
}
if (!cuf::isCUDADeviceContext(builder.getRegion()))
return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
lenParams, indices);
}
// Let the builder do all the heavy lifting.
@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
return res;
}
/// Device allocatable components in a derived-type don't have the correct
/// allocator index in their descriptor when they are created. After
/// initialization, cuf.set_allocator_idx operations are inserted to set the
/// correct allocator index for each device component.
static void
initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &symbol,
Fortran::lower::SymMap &symMap) {
if (const auto *details{
symbol.GetUltimate()
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
const Fortran::semantics::DeclTypeSpec *type{details->type()};
const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
: nullptr};
if (derived) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Location loc = converter.getCurrentLocation();
fir::ExtendedValue exv =
converter.getSymbolExtendedValue(symbol.GetUltimate(), &symMap);
auto recTy = mlir::dyn_cast<fir::RecordType>(
fir::unwrapRefType(fir::getBase(exv).getType()));
assert(recTy && "expected fir::RecordType");
llvm::SmallVector<mlir::Value> coordinates;
Fortran::semantics::UltimateComponentIterator components{*derived};
for (const auto &sym : components) {
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
mlir::Type fieldTy;
std::vector<mlir::Value> coordinates;
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
// Field found in the base record type.
auto fieldName = recTy.getTypeList()[fieldIdx].first;
fieldTy = recTy.getTypeList()[fieldIdx].second;
mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(fieldIndex);
} else {
// Field not found in base record type, search in potential
// record type components.
for (auto component : recTy.getTypeList()) {
if (auto childRecTy =
mlir::dyn_cast<fir::RecordType>(component.second)) {
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
mlir::Value parentFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(childRecTy.getContext()),
component.first, recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(parentFieldIndex);
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
mlir::Value childFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()),
fieldName, childRecTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(childFieldIndex);
break;
}
}
}
}
if (coordinates.empty())
TODO(loc, "device resident component in complex derived-type "
"hierarchy");
mlir::Value comp = builder.create<fir::CoordinateOp>(
loc, builder.getRefType(fieldTy), fir::getBase(exv), coordinates);
cuf::DataAttributeAttr dataAttr =
Fortran::lower::translateSymbolCUFDataAttribute(
builder.getContext(), sym);
builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
}
}
}
}
}
/// Must \p var be default initialized at runtime when entering its scope.
static bool
mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) {
@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
if (mustBeDefaultInitializedAtRuntime(var))
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
symMap);
if (converter.getFoldingContext().languageFeatures().IsEnabled(
Fortran::common::LanguageFeature::CUDA))
initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
auto *builder = &converter.getFirOpBuilder();
if (needCUDAAlloc(var.getSymbol()) &&
!cuf::isCUDADeviceContext(builder->getRegion())) {
@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter,
if (mustBeDefaultInitializedAtRuntime(var))
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
symMap);
if (converter.getFoldingContext().languageFeatures().IsEnabled(
Fortran::common::LanguageFeature::CUDA))
initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
}
//===--------------------------------------------------------------===//

View File

@ -12,10 +12,13 @@ contains
end subroutine
! CHECK-LABEL: func.func @_QMm1Psub1()
! CHECK: %[[DT:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]], x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
! CHECK: %[[DT:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
! CHECK: fir.address_of(@_QQ_QMm1Tty_device.DerivedInit)
! CHECK: fir.copy
! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]]#0, x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: cuf.set_allocator_idx %[[X]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]], z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]]#0, z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
end module