[OpenACC][MLIR] clone reduction operands during ACCIfClauseLowering (#177196)
Clone the reduction operands into the device (compute region) side of the lowered if clause. This also fixes an issue where references to acc.reduction results remained on the host side.
This commit is contained in:
parent
c6afb03658
commit
49903c4e64
@ -137,6 +137,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
||||
SmallVector<Operation *> dataEntryOps;
|
||||
SmallVector<Operation *> dataExitOps;
|
||||
SmallVector<Operation *> firstprivateOps;
|
||||
SmallVector<Operation *> reductionOps;
|
||||
|
||||
// Collect data entry operations
|
||||
for (Value operand : computeConstructOp.getDataClauseOperands()) {
|
||||
@ -150,6 +151,12 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
||||
firstprivateOps.push_back(defOp);
|
||||
}
|
||||
|
||||
// Collect reduction operations
|
||||
for (Value operand : computeConstructOp.getReductionOperands()) {
|
||||
if (Operation *defOp = operand.getDefiningOp())
|
||||
reductionOps.push_back(defOp);
|
||||
}
|
||||
|
||||
// Find corresponding exit operations for each entry operation.
|
||||
// Iterate backwards through entry ops since exit ops appear in reverse order.
|
||||
for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
|
||||
@ -171,6 +178,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
||||
// Clone data entry operations
|
||||
SmallVector<Value> deviceDataOperands;
|
||||
SmallVector<Value> firstprivateOperands;
|
||||
SmallVector<Value> reductionOperands;
|
||||
|
||||
// Map the data entry, firstprivate, and reduction ops for the cloned region
|
||||
IRMapping deviceMapping;
|
||||
@ -184,6 +192,11 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
||||
firstprivateOperands.push_back(clonedOp->getResult(0));
|
||||
deviceMapping.map(firstprivateOp->getResult(0), clonedOp->getResult(0));
|
||||
}
|
||||
for (Operation *reductionOp : reductionOps) {
|
||||
Operation *clonedOp = rewriter.clone(*reductionOp, deviceMapping);
|
||||
reductionOperands.push_back(clonedOp->getResult(0));
|
||||
deviceMapping.map(reductionOp->getResult(0), clonedOp->getResult(0));
|
||||
}
|
||||
|
||||
// Create new compute op without if condition for device execution by
|
||||
// cloning
|
||||
@ -192,6 +205,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
||||
newComputeOp.getIfCondMutable().clear();
|
||||
newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
|
||||
newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
|
||||
newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
|
||||
|
||||
// Clone data exit operations
|
||||
rewriter.setInsertionPointAfter(newComputeOp);
|
||||
@ -238,6 +252,10 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
||||
getAccVar(firstprivateOp).replaceAllUsesWith(getVar(firstprivateOp));
|
||||
eraseOps.push_back(firstprivateOp);
|
||||
}
|
||||
for (Operation *reductionOp : reductionOps) {
|
||||
getAccVar(reductionOp).replaceAllUsesWith(getVar(reductionOp));
|
||||
eraseOps.push_back(reductionOp);
|
||||
}
|
||||
}
|
||||
|
||||
void ACCIfClauseLowering::runOnOperation() {
|
||||
|
||||
@ -247,7 +247,7 @@ func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref<i32>, %con
|
||||
%copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32>
|
||||
%firstprivate = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
|
||||
|
||||
// In the else branch, uses of %copyin should be replaced with %arg0
|
||||
// In the else branch, uses of %firstprivate should be replaced with %arg1
|
||||
// CHECK: scf.if
|
||||
// CHECK: [[FIRSTPRIVATE:%.*]] = acc.firstprivate varPtr(%arg1 : memref<i32>) recipe(@memref_i32) -> memref<i32>
|
||||
// CHECK: acc.parallel {{.*}} firstprivate([[FIRSTPRIVATE]] : memref<i32>) {
|
||||
@ -268,3 +268,47 @@ func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref<i32>, %con
|
||||
acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test that acc variable uses in the host path are replaced with host variables
|
||||
// and that the reduction operands are cloned for the device path
|
||||
// CHECK-LABEL: func.func @test_acc_reduction
|
||||
|
||||
acc.reduction.recipe @reduction_add_memref_i32 : memref<i32> reduction_operator <add> init {
|
||||
^bb0(%arg0: memref<i32>):
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%0 = memref.alloca() : memref<i32>
|
||||
memref.store %c0_i32, %0[] : memref<i32>
|
||||
acc.yield %0 : memref<i32>
|
||||
} combiner {
|
||||
^bb0(%arg0: memref<i32>, %arg1: memref<i32>):
|
||||
%0 = memref.load %arg1[] : memref<i32>
|
||||
%1 = memref.load %arg0[] : memref<i32>
|
||||
%2 = arith.addi %1, %0 : i32
|
||||
memref.store %2, %arg0[] : memref<i32>
|
||||
acc.yield %arg0 : memref<i32>
|
||||
}
|
||||
|
||||
func.func @test_acc_reduction(%arg0: memref<i32>, %cond: i1) {
|
||||
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%reduction = acc.reduction varPtr(%arg0 : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32>
|
||||
|
||||
// In the else branch, uses of %reduction should be replaced with %arg0
|
||||
// CHECK: scf.if
|
||||
// CHECK: [[REDUCTION:%.*]] = acc.reduction varPtr(%arg0 : memref<i32>) recipe(@reduction_add_memref_i32) -> memref<i32>
|
||||
// CHECK: acc.parallel reduction([[REDUCTION]] : memref<i32>) {
|
||||
// CHECK: } else {
|
||||
// CHECK: [[LOAD:%.*]] = memref.load %arg0[] : memref<i32>
|
||||
// CHECK: memref.store {{.*}}, %arg0[] : memref<i32>
|
||||
// CHECK: }
|
||||
|
||||
acc.parallel reduction(%reduction : memref<i32>) if(%cond) {
|
||||
%load = memref.load %reduction[] : memref<i32>
|
||||
%add = arith.addi %load, %c0_i32 : i32
|
||||
memref.store %add, %reduction[] : memref<i32>
|
||||
acc.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user