[acc] Lower acc if with multi-block host fallback via scf.execute_region (#188350)

Handle multi-block host fallback regions by wrapping them in
scf.execute_region instead of rejecting them with `not yet implemented:
region with multiple blocks`.
This commit is contained in:
Susan Tan (ス-ザン タン) 2026-03-25 11:09:37 -04:00 committed by GitHub
parent 7aaec28fde
commit 33da12aae7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 58 additions and 15 deletions

View File

@ -59,6 +59,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
@ -215,23 +216,29 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
// Host execution path (false branch)
if (!computeConstructOp.getRegion().hasOneBlock()) {
accSupport->emitNYI(computeConstructOp.getLoc(),
"region with multiple blocks");
return;
Region &hostRegion = computeConstructOp.getRegion();
if (hostRegion.hasOneBlock()) {
// Don't need to clone original ops, just take them and legalize for host.
ifOp.getElseRegion().takeBody(hostRegion);
// Swap acc yield for scf yield.
Block &elseBlock = ifOp.getElseRegion().front();
elseBlock.getTerminator()->erase();
rewriter.setInsertionPointToEnd(&elseBlock);
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
convertHostRegion(computeConstructOp, ifOp.getElseRegion());
} else {
// scf.if regions must stay single-block. Wrap the original multi-block ACC
// body in scf.execute_region so it can be hosted in the else branch.
Block &elseBlock = ifOp.getElseRegion().front();
rewriter.setInsertionPoint(elseBlock.getTerminator());
IRMapping hostMapping;
auto hostExecuteRegion = wrapMultiBlockRegionWithSCFExecuteRegion(
hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter);
convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion());
}
// Don't need to clone original ops, just take them and legalize for host
ifOp.getElseRegion().takeBody(computeConstructOp.getRegion());
// Swap acc yield for scf yield
Block &elseBlock = ifOp.getElseRegion().front();
elseBlock.getTerminator()->erase();
rewriter.setInsertionPointToEnd(&elseBlock);
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
convertHostRegion(computeConstructOp, ifOp.getElseRegion());
// The original op is now empty and can be erased
eraseOps.push_back(computeConstructOp);

View File

@ -37,6 +37,42 @@ func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) {
// -----
// Test acc.parallel if lowering when host fallback region has multiple blocks.
// The parallel region below is written as explicit control flow over four
// blocks (entry, ^bb1, ^bb2, ^bb3), so its body is not a single block.
// CHECK-LABEL: func.func @test_parallel_if_multiblock
func.func @test_parallel_if_multiblock(%cond: i1, %n: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
// Scalar counter, initialised from %n, that the region reads and decrements.
%counter = memref.alloca() : memref<i32>
memref.store %n, %counter[] : memref<i32>
// Expected output: the if clause is gone from acc.parallel; an scf.if holds
// the acc.parallel in its then branch, and the else branch holds an
// scf.execute_region that still contains the branching body and ends in
// scf.yield.
// CHECK-NOT: acc.parallel if
// CHECK: scf.if %{{.*}} {
// CHECK: acc.parallel {
// CHECK: } else {
// CHECK: scf.execute_region {
// CHECK: ^bb
// CHECK: cf.cond_br
// CHECK: scf.yield
// CHECK: }
// CHECK: }
acc.parallel if(%cond) {
// Entry block: jump straight to the loop header.
cf.br ^bb1
^bb1:
// Loop header: continue while the counter is greater than zero.
%v = memref.load %counter[] : memref<i32>
%pred = arith.cmpi sgt, %v, %c0_i32 : i32
cf.cond_br %pred, ^bb2, ^bb3
^bb2:
// Loop body: decrement the counter and branch back to the header.
%next = arith.subi %v, %c1_i32 : i32
memref.store %next, %counter[] : memref<i32>
cf.br ^bb1
^bb3:
// Exit block: terminate the region.
acc.yield
}
return
}
// -----
// Test acc.kernels with if condition
// CHECK-LABEL: func.func @test_kernels_if
func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {