diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp index ddefa1653f21..c103ba29ed28 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp @@ -137,6 +137,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct( SmallVector dataEntryOps; SmallVector dataExitOps; SmallVector firstprivateOps; + SmallVector reductionOps; // Collect data entry operations for (Value operand : computeConstructOp.getDataClauseOperands()) { @@ -150,6 +151,12 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct( firstprivateOps.push_back(defOp); } + // Collect reduction operations + for (Value operand : computeConstructOp.getReductionOperands()) { + if (Operation *defOp = operand.getDefiningOp()) + reductionOps.push_back(defOp); + } + // Find corresponding exit operations for each entry operation. // Iterate backwards through entry ops since exit ops appear in reverse order. for (Operation *dataEntryOp : llvm::reverse(dataEntryOps)) @@ -171,6 +178,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct( // Clone data entry operations SmallVector deviceDataOperands; SmallVector firstprivateOperands; + SmallVector reductionOperands; // Map the data entry and firstprivate ops for the cloned region IRMapping deviceMapping; @@ -184,6 +192,11 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct( firstprivateOperands.push_back(clonedOp->getResult(0)); deviceMapping.map(firstprivateOp->getResult(0), clonedOp->getResult(0)); } + for (Operation *reductionOp : reductionOps) { + Operation *clonedOp = rewriter.clone(*reductionOp, deviceMapping); + reductionOperands.push_back(clonedOp->getResult(0)); + deviceMapping.map(reductionOp->getResult(0), clonedOp->getResult(0)); + } // Create new compute op without if condition for device execution by // cloning @@ -192,6 +205,7 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct( newComputeOp.getIfCondMutable().clear(); newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands); newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands); + newComputeOp.getReductionOperandsMutable().assign(reductionOperands); // Clone data exit operations rewriter.setInsertionPointAfter(newComputeOp); @@ -238,6 +252,10 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct( getAccVar(firstprivateOp).replaceAllUsesWith(getVar(firstprivateOp)); eraseOps.push_back(firstprivateOp); } + for (Operation *reductionOp : reductionOps) { + getAccVar(reductionOp).replaceAllUsesWith(getVar(reductionOp)); + eraseOps.push_back(reductionOp); + } } void ACCIfClauseLowering::runOnOperation() { diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir index fdef532fb8cb..5b942c121d56 100644 --- a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir +++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir @@ -247,7 +247,7 @@ func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref, %con %copyin = acc.copyin varPtr(%arg0 : memref<10xi32>) -> memref<10xi32> %firstprivate = acc.firstprivate varPtr(%arg1 : memref) recipe(@memref_i32) -> memref - // In the else branch, uses of %copyin should be replaced with %arg0 + // In the else branch, uses of %firstprivate should be replaced with %arg0 // CHECK: scf.if // CHECK: [[FIRSTPRIVATE:%.*]] = acc.firstprivate varPtr(%arg1 : memref) recipe(@memref_i32) -> memref // CHECK: acc.parallel {{.*}} firstprivate([[FIRSTPRIVATE]] : memref) { @@ -268,3 +268,47 @@ func.func @test_acc_firstprivate(%arg0: memref<10xi32>, %arg1: memref, %con acc.copyout accPtr(%copyin : memref<10xi32>) to varPtr(%arg0 : memref<10xi32>) return } + +// ----- + +// Test that acc variable uses in host path are replaced with host variables; +// and the reduction operands are cloned +// CHECK-LABEL: func.func @test_acc_reduction + +acc.reduction.recipe @reduction_add_memref_i32 : memref reduction_operator init { +^bb0(%arg0: memref): + %c0_i32 = arith.constant 0 : i32 + %0 = memref.alloca() : memref + memref.store %c0_i32, %0[] : memref + acc.yield %0 : memref +} combiner { +^bb0(%arg0: memref, %arg1: memref): + %0 = memref.load %arg1[] : memref + %1 = memref.load %arg0[] : memref + %2 = arith.addi %1, %0 : i32 + memref.store %2, %arg0[] : memref + acc.yield %arg0 : memref +} + +func.func @test_acc_reduction(%arg0: memref, %cond: i1) { + + %c0_i32 = arith.constant 0 : i32 + %reduction = acc.reduction varPtr(%arg0 : memref) recipe(@reduction_add_memref_i32) -> memref + + // In the else branch, uses of %reduction should be replaced with %arg0 + // CHECK: scf.if + // CHECK: [[REDUCTION:%.*]] = acc.reduction varPtr(%arg0 : memref) recipe(@reduction_add_memref_i32) -> memref + // CHECK: acc.parallel reduction([[REDUCTION]] : memref) { + // CHECK: } else { + // CHECK: [[LOAD:%.*]] = memref.load %arg0[] : memref + // CHECK: memref.store {{.*}}, %arg0[] : memref + // CHECK: } + + acc.parallel reduction(%reduction : memref) if(%cond) { + %load = memref.load %reduction[] : memref + %add = arith.addi %load, %c0_i32 : i32 + memref.store %add, %reduction[] : memref + acc.yield + } + return +}