From b4e2272271ee85273ca871abac2f6e9342da143d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Tue, 15 Jul 2025 13:52:00 -0700 Subject: [PATCH] [flang][cuda] Move cuf.set_allocator_idx after derived-type init (#148936) Derived type initialization overwrite the component descriptor. Place the `cuf.set_allocator_idx` after the initialization is performed. --- flang/lib/Lower/ConvertVariable.cpp | 167 +++++++++++-------- flang/test/Lower/CUDA/cuda-set-allocator.cuf | 9 +- 2 files changed, 100 insertions(+), 76 deletions(-) diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp index ffe456de5663..2bfa9618aa4b 100644 --- a/flang/lib/Lower/ConvertVariable.cpp +++ b/flang/lib/Lower/ConvertVariable.cpp @@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter, return builder.create(loc, ty, nm, symNm, lenParams, indices); - if (!cuf::isCUDADeviceContext(builder.getRegion())) { - mlir::Value alloc = builder.create( - loc, ty, nm, symNm, dataAttr, lenParams, indices); - if (const auto *details{ - ultimateSymbol - .detailsIf()}) { - const Fortran::semantics::DeclTypeSpec *type{details->type()}; - const Fortran::semantics::DerivedTypeSpec *derived{ - type ? type->AsDerived() : nullptr}; - if (derived) { - Fortran::semantics::UltimateComponentIterator components{*derived}; - auto recTy = mlir::dyn_cast(ty); - - llvm::SmallVector coordinates; - for (const auto &sym : components) { - if (Fortran::semantics::IsDeviceAllocatable(sym)) { - unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); - mlir::Type fieldTy; - std::vector coordinates; - - if (fieldIdx != std::numeric_limits::max()) { - // Field found in the base record type. - auto fieldName = recTy.getTypeList()[fieldIdx].first; - fieldTy = recTy.getTypeList()[fieldIdx].second; - mlir::Value fieldIndex = builder.create( - loc, fir::FieldType::get(fieldTy.getContext()), fieldName, - recTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(fieldIndex); - } else { - // Field not found in base record type, search in potential - // record type components. - for (auto component : recTy.getTypeList()) { - if (auto childRecTy = - mlir::dyn_cast(component.second)) { - fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); - if (fieldIdx != std::numeric_limits::max()) { - mlir::Value parentFieldIndex = - builder.create( - loc, fir::FieldType::get(childRecTy.getContext()), - component.first, recTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(parentFieldIndex); - auto fieldName = childRecTy.getTypeList()[fieldIdx].first; - fieldTy = childRecTy.getTypeList()[fieldIdx].second; - mlir::Value childFieldIndex = - builder.create( - loc, fir::FieldType::get(fieldTy.getContext()), - fieldName, childRecTy, - /*typeParams=*/mlir::ValueRange{}); - coordinates.push_back(childFieldIndex); - break; - } - } - } - } - - if (coordinates.empty()) - TODO(loc, "device resident component in complex derived-type " - "hierarchy"); - - mlir::Value comp = builder.create( - loc, builder.getRefType(fieldTy), alloc, coordinates); - cuf::DataAttributeAttr dataAttr = - Fortran::lower::translateSymbolCUFDataAttribute( - builder.getContext(), sym); - builder.create(loc, comp, dataAttr); - } - } - } - } - return alloc; - } + if (!cuf::isCUDADeviceContext(builder.getRegion())) + return builder.create(loc, ty, nm, symNm, dataAttr, + lenParams, indices); } // Let the builder do all the heavy lifting. @@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter, return res; } +/// Device allocatable components in a derived-type don't have the correct +/// allocator index in their descriptor when they are created. After +/// initialization, cuf.set_allocator_idx operations are inserted to set the +/// correct allocator index for each device component. +static void +initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter, + const Fortran::semantics::Symbol &symbol, + Fortran::lower::SymMap &symMap) { + if (const auto *details{ + symbol.GetUltimate() + .detailsIf()}) { + const Fortran::semantics::DeclTypeSpec *type{details->type()}; + const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived() + : nullptr}; + if (derived) { + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); + mlir::Location loc = converter.getCurrentLocation(); + + fir::ExtendedValue exv = + converter.getSymbolExtendedValue(symbol.GetUltimate(), &symMap); + auto recTy = mlir::dyn_cast( + fir::unwrapRefType(fir::getBase(exv).getType())); + assert(recTy && "expected fir::RecordType"); + + llvm::SmallVector coordinates; + Fortran::semantics::UltimateComponentIterator components{*derived}; + for (const auto &sym : components) { + if (Fortran::semantics::IsDeviceAllocatable(sym)) { + unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); + mlir::Type fieldTy; + std::vector coordinates; + + if (fieldIdx != std::numeric_limits::max()) { + // Field found in the base record type. + auto fieldName = recTy.getTypeList()[fieldIdx].first; + fieldTy = recTy.getTypeList()[fieldIdx].second; + mlir::Value fieldIndex = builder.create( + loc, fir::FieldType::get(fieldTy.getContext()), fieldName, + recTy, + /*typeParams=*/mlir::ValueRange{}); + coordinates.push_back(fieldIndex); + } else { + // Field not found in base record type, search in potential + // record type components. + for (auto component : recTy.getTypeList()) { + if (auto childRecTy = + mlir::dyn_cast(component.second)) { + fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); + if (fieldIdx != std::numeric_limits::max()) { + mlir::Value parentFieldIndex = + builder.create( + loc, fir::FieldType::get(childRecTy.getContext()), + component.first, recTy, + /*typeParams=*/mlir::ValueRange{}); + coordinates.push_back(parentFieldIndex); + auto fieldName = childRecTy.getTypeList()[fieldIdx].first; + fieldTy = childRecTy.getTypeList()[fieldIdx].second; + mlir::Value childFieldIndex = + builder.create( + loc, fir::FieldType::get(fieldTy.getContext()), + fieldName, childRecTy, + /*typeParams=*/mlir::ValueRange{}); + coordinates.push_back(childFieldIndex); + break; + } + } + } + } + + if (coordinates.empty()) + TODO(loc, "device resident component in complex derived-type " + "hierarchy"); + + mlir::Value comp = builder.create( + loc, builder.getRefType(fieldTy), fir::getBase(exv), coordinates); + cuf::DataAttributeAttr dataAttr = + Fortran::lower::translateSymbolCUFDataAttribute( + builder.getContext(), sym); + builder.create(loc, comp, dataAttr); + } + } + } + } +} + /// Must \p var be default initialized at runtime when entering its scope. static bool mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) { @@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter, if (mustBeDefaultInitializedAtRuntime(var)) Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(), symMap); + if (converter.getFoldingContext().languageFeatures().IsEnabled( + Fortran::common::LanguageFeature::CUDA)) + initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap); auto *builder = &converter.getFirOpBuilder(); if (needCUDAAlloc(var.getSymbol()) && !cuf::isCUDADeviceContext(builder->getRegion())) { @@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter, if (mustBeDefaultInitializedAtRuntime(var)) Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(), symMap); + if (converter.getFoldingContext().languageFeatures().IsEnabled( + Fortran::common::LanguageFeature::CUDA)) + initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap); } //===--------------------------------------------------------------===// diff --git a/flang/test/Lower/CUDA/cuda-set-allocator.cuf b/flang/test/Lower/CUDA/cuda-set-allocator.cuf index bf74e012a639..ee89ea38a3fc 100644 --- a/flang/test/Lower/CUDA/cuda-set-allocator.cuf +++ b/flang/test/Lower/CUDA/cuda-set-allocator.cuf @@ -12,10 +12,13 @@ contains end subroutine ! CHECK-LABEL: func.func @_QMm1Psub1() -! CHECK: %[[DT:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box>>,y:i32,z:!fir.box>>}> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref>>,y:i32,z:!fir.box>>}>> -! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]], x : (!fir.ref>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> +! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box>>,y:i32,z:!fir.box>>}> {bindc_name = "a", data_attr = #cuf.cuda, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref>>,y:i32,z:!fir.box>>}>> +! CHECK: %[[DT:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref>>,y:i32,z:!fir.box>>}>>) -> (!fir.ref>>,y:i32,z:!fir.box>>}>>, !fir.ref>>,y:i32,z:!fir.box>>}>>) +! CHECK: fir.address_of(@_QQ_QMm1Tty_device.DerivedInit) +! CHECK: fir.copy +! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]]#0, x : (!fir.ref>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> ! CHECK: cuf.set_allocator_idx %[[X]] : !fir.ref>>> {data_attr = #cuf.cuda} -! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]], z : (!fir.ref>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> +! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]]#0, z : (!fir.ref>>,y:i32,z:!fir.box>>}>>) -> !fir.ref>>> ! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref>>> {data_attr = #cuf.cuda} end module