[flang][cuda] Support data transfer with parenthesis around rhs (#183201)
This commit is contained in:
parent
5524ce826d
commit
a224ba0689
@ -64,7 +64,8 @@ translateSymbolCUFDataAttribute(mlir::MLIRContext *mlirContext,
|
||||
|
||||
/// Check if the rhs has an implicit conversion. Return the elemental op if
|
||||
/// there is a conversion. Return null otherwise.
|
||||
hlfir::ElementalOp isTransferWithConversion(mlir::Value rhs);
|
||||
std::pair<hlfir::ElementalOp, hlfir::ElementalOp>
|
||||
isTransferWithConversion(mlir::Value rhs);
|
||||
|
||||
/// Check if the value is an allocatable with double descriptor.
|
||||
bool hasDoubleDescriptor(mlir::Value);
|
||||
|
||||
@ -5436,11 +5436,12 @@ private:
|
||||
|
||||
// host = device
|
||||
if (!lhsIsDevice && rhsIsDevice) {
|
||||
if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) {
|
||||
auto [firstElOp, elOp] = Fortran::lower::isTransferWithConversion(rhs);
|
||||
if (firstElOp) {
|
||||
mlir::OpBuilder::InsertionGuard insertionGuard(builder);
|
||||
auto designateOp =
|
||||
*elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin();
|
||||
builder.setInsertionPoint(elementalOp);
|
||||
*firstElOp.getBody()->getOps<hlfir::DesignateOp>().begin();
|
||||
builder.setInsertionPoint(firstElOp);
|
||||
// Create a temp to transfer the rhs before applying the conversion.
|
||||
hlfir::Entity entity{designateOp.getMemref()};
|
||||
auto [temp, cleanup] = hlfir::createTempFromMold(loc, builder, entity);
|
||||
@ -5449,8 +5450,8 @@ private:
|
||||
cuf::DataTransferOp::create(builder, loc, designateOp.getMemref(), temp,
|
||||
/*shape=*/mlir::Value{}, transferKindAttr);
|
||||
designateOp.getMemrefMutable().assign(temp);
|
||||
builder.setInsertionPointAfter(elementalOp);
|
||||
hlfir::AssignOp::create(builder, loc, elementalOp, lhs,
|
||||
builder.setInsertionPointAfter(elOp);
|
||||
hlfir::AssignOp::create(builder, loc, elOp, lhs,
|
||||
isWholeAllocatableAssignment,
|
||||
keepLhsLengthInAllocatableAssignment);
|
||||
return;
|
||||
|
||||
@ -68,7 +68,15 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
|
||||
return cuf::getDataAttribute(mlirContext, cudaAttr);
|
||||
}
|
||||
|
||||
hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
|
||||
std::pair<hlfir::ElementalOp, hlfir::ElementalOp>
|
||||
Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
|
||||
auto isCopyElementalOp = [](hlfir::ElementalOp elOp) {
|
||||
return llvm::hasSingleElement(
|
||||
elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
|
||||
llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 &&
|
||||
llvm::hasSingleElement(
|
||||
elOp.getBody()->getOps<hlfir::NoReassocOp>()) == 1;
|
||||
};
|
||||
auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
|
||||
return llvm::hasSingleElement(
|
||||
elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
|
||||
@ -76,6 +84,11 @@ hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
|
||||
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) ==
|
||||
1;
|
||||
};
|
||||
auto isConversionFromCopyElementalOp = [](hlfir::ElementalOp elOp) {
|
||||
return llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::ApplyOp>()) &&
|
||||
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) ==
|
||||
1;
|
||||
};
|
||||
if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) {
|
||||
if (!declOp.getMemref().getDefiningOp())
|
||||
return {};
|
||||
@ -84,11 +97,20 @@ hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
|
||||
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(
|
||||
associateOp.getSource().getDefiningOp()))
|
||||
if (isConversionElementalOp(elOp))
|
||||
return elOp;
|
||||
return {elOp, elOp};
|
||||
}
|
||||
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()))
|
||||
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp())) {
|
||||
if (isConversionFromCopyElementalOp(elOp)) {
|
||||
auto applyOp = *elOp.getBody()->getOps<hlfir::ApplyOp>().begin();
|
||||
if (auto firstElOp = mlir::dyn_cast<hlfir::ElementalOp>(
|
||||
applyOp.getExpr().getDefiningOp())) {
|
||||
if (isCopyElementalOp(firstElOp))
|
||||
return {firstElOp, elOp};
|
||||
}
|
||||
}
|
||||
if (isConversionElementalOp(elOp))
|
||||
return elOp;
|
||||
return {elOp, elOp};
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
|
||||
@ -616,3 +616,16 @@ end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPsub32()
|
||||
! CHECK-COUNT-2: cuf.data_transfer
|
||||
|
||||
subroutine sub33(m, n)
|
||||
integer :: m, n
|
||||
real(2), managed :: dc(m,n)
|
||||
real(4) :: c(m,n)
|
||||
|
||||
c = (dc)
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPsub33
|
||||
! CHECK: cuf.data_transfer
|
||||
! CHECK-COUNT-2: hlfir.elemental
|
||||
! CHECK: hlfir.assign
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user