[acc] Lower acc if with multi-block host fallback via scf.execute_region (#188350)
handle multi-block host fallback regions by wrapping them in scf.execute_region, instead of rejecting with `not yet implemented: region with multiple blocks`.
This commit is contained in:
parent
7aaec28fde
commit
33da12aae7
@ -59,6 +59,7 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
||||
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
@ -215,23 +216,29 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
||||
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
|
||||
|
||||
// Host execution path (false branch)
|
||||
if (!computeConstructOp.getRegion().hasOneBlock()) {
|
||||
accSupport->emitNYI(computeConstructOp.getLoc(),
|
||||
"region with multiple blocks");
|
||||
return;
|
||||
Region &hostRegion = computeConstructOp.getRegion();
|
||||
if (hostRegion.hasOneBlock()) {
|
||||
// Don't need to clone original ops, just take them and legalize for host.
|
||||
ifOp.getElseRegion().takeBody(hostRegion);
|
||||
|
||||
// Swap acc yield for scf yield.
|
||||
Block &elseBlock = ifOp.getElseRegion().front();
|
||||
elseBlock.getTerminator()->erase();
|
||||
rewriter.setInsertionPointToEnd(&elseBlock);
|
||||
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
|
||||
|
||||
convertHostRegion(computeConstructOp, ifOp.getElseRegion());
|
||||
} else {
|
||||
// scf.if regions must stay single-block. Wrap the original multi-block ACC
|
||||
// body in scf.execute_region so it can be hosted in the else branch.
|
||||
Block &elseBlock = ifOp.getElseRegion().front();
|
||||
rewriter.setInsertionPoint(elseBlock.getTerminator());
|
||||
IRMapping hostMapping;
|
||||
auto hostExecuteRegion = wrapMultiBlockRegionWithSCFExecuteRegion(
|
||||
hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter);
|
||||
convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion());
|
||||
}
|
||||
|
||||
// Don't need to clone original ops, just take them and legalize for host
|
||||
ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
|
||||
|
||||
// Swap acc yield for scf yield
|
||||
Block &elseBlock = ifOp.getElseRegion().front();
|
||||
elseBlock.getTerminator()->erase();
|
||||
rewriter.setInsertionPointToEnd(&elseBlock);
|
||||
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
|
||||
|
||||
convertHostRegion(computeConstructOp, ifOp.getElseRegion());
|
||||
|
||||
// The original op is now empty and can be erased
|
||||
eraseOps.push_back(computeConstructOp);
|
||||
|
||||
|
||||
@ -37,6 +37,42 @@ func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
|
||||
|
||||
// -----
|
||||
|
||||
// Test acc.parallel if lowering when host fallback region has multiple blocks.
|
||||
// CHECK-LABEL: func.func @test_parallel_if_multiblock
|
||||
func.func @test_parallel_if_multiblock(%cond: i1, %n: i32) {
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%counter = memref.alloca() : memref<i32>
|
||||
memref.store %n, %counter[] : memref<i32>
|
||||
|
||||
// CHECK-NOT: acc.parallel if
|
||||
// CHECK: scf.if %{{.*}} {
|
||||
// CHECK: acc.parallel {
|
||||
// CHECK: } else {
|
||||
// CHECK: scf.execute_region {
|
||||
// CHECK: ^bb
|
||||
// CHECK: cf.cond_br
|
||||
// CHECK: scf.yield
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
acc.parallel if(%cond) {
|
||||
cf.br ^bb1
|
||||
^bb1:
|
||||
%v = memref.load %counter[] : memref<i32>
|
||||
%pred = arith.cmpi sgt, %v, %c0_i32 : i32
|
||||
cf.cond_br %pred, ^bb2, ^bb3
|
||||
^bb2:
|
||||
%next = arith.subi %v, %c1_i32 : i32
|
||||
memref.store %next, %counter[] : memref<i32>
|
||||
cf.br ^bb1
|
||||
^bb3:
|
||||
acc.yield
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test acc.kernels with if condition
|
||||
// CHECK-LABEL: func.func @test_kernels_if
|
||||
func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user