[flang][cuda] Support data transfer with parenthesis around rhs (#183201)

This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2026-02-25 10:44:31 -08:00 committed by GitHub
parent 5524ce826d
commit a224ba0689
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 47 additions and 10 deletions

View File

@ -64,7 +64,8 @@ translateSymbolCUFDataAttribute(mlir::MLIRContext *mlirContext,
/// Check if the rhs has an implicit conversion. Return the elemental op if
/// there is a conversion. Return null otherwise.
hlfir::ElementalOp isTransferWithConversion(mlir::Value rhs);
std::pair<hlfir::ElementalOp, hlfir::ElementalOp>
isTransferWithConversion(mlir::Value rhs);
/// Check if the value is an allocatable with double descriptor.
bool hasDoubleDescriptor(mlir::Value);

View File

@ -5436,11 +5436,12 @@ private:
// host = device
if (!lhsIsDevice && rhsIsDevice) {
if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) {
auto [firstElOp, elOp] = Fortran::lower::isTransferWithConversion(rhs);
if (firstElOp) {
mlir::OpBuilder::InsertionGuard insertionGuard(builder);
auto designateOp =
*elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin();
builder.setInsertionPoint(elementalOp);
*firstElOp.getBody()->getOps<hlfir::DesignateOp>().begin();
builder.setInsertionPoint(firstElOp);
// Create a temp to transfer the rhs before applying the conversion.
hlfir::Entity entity{designateOp.getMemref()};
auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, entity);
@ -5449,8 +5450,8 @@ private:
cuf::DataTransferOp::create(builder, loc, designateOp.getMemref(), temp,
/*shape=*/mlir::Value{}, transferKindAttr);
designateOp.getMemrefMutable().assign(temp);
builder.setInsertionPointAfter(elementalOp);
hlfir::AssignOp::create(builder, loc, elementalOp, lhs,
builder.setInsertionPointAfter(elOp);
hlfir::AssignOp::create(builder, loc, elOp, lhs,
isWholeAllocatableAssignment,
keepLhsLengthInAllocatableAssignment);
return;

View File

@ -68,7 +68,15 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
return cuf::getDataAttribute(mlirContext, cudaAttr);
}
hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
std::pair<hlfir::ElementalOp, hlfir::ElementalOp>
Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
auto isCopyElementalOp = [](hlfir::ElementalOp elOp) {
return llvm::hasSingleElement(
elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 &&
llvm::hasSingleElement(
elOp.getBody()->getOps<hlfir::NoReassocOp>()) == 1;
};
auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
return llvm::hasSingleElement(
elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
@ -76,6 +84,11 @@ hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) ==
1;
};
auto isConversionFromCopyElementalOp = [](hlfir::ElementalOp elOp) {
return llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::ApplyOp>()) &&
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) ==
1;
};
if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) {
if (!declOp.getMemref().getDefiningOp())
return {};
@ -84,11 +97,20 @@ hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(
associateOp.getSource().getDefiningOp()))
if (isConversionElementalOp(elOp))
return elOp;
return {elOp, elOp};
}
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()))
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp())) {
if (isConversionFromCopyElementalOp(elOp)) {
auto applyOp = *elOp.getBody()->getOps<hlfir::ApplyOp>().begin();
if (auto firstElOp = mlir::dyn_cast<hlfir::ElementalOp>(
applyOp.getExpr().getDefiningOp())) {
if (isCopyElementalOp(firstElOp))
return {firstElOp, elOp};
}
}
if (isConversionElementalOp(elOp))
return elOp;
return {elOp, elOp};
}
return {};
}

View File

@ -616,3 +616,16 @@ end subroutine
! CHECK-LABEL: func.func @_QPsub32()
! CHECK-COUNT-2: cuf.data_transfer
subroutine sub33(m, n)
integer :: m, n
real(2), managed :: dc(m,n)
real(4) :: c(m,n)
c = (dc)
end subroutine
! CHECK-LABEL: func.func @_QPsub33
! CHECK: cuf.data_transfer
! CHECK-COUNT-2: hlfir.elemental
! CHECK: hlfir.assign