From 33da12aae70ce26568aa06538329fab0481dcb4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Susan=20Tan=20=28=E3=82=B9-=E3=82=B6=E3=83=B3=E3=80=80?= =?UTF-8?q?=E3=82=BF=E3=83=B3=29?= Date: Wed, 25 Mar 2026 11:09:37 -0400 Subject: [PATCH] [acc] Lower acc if with multi-block host fallback via scf.execute_region (#188350) handle multi-block host fallback regions by wrapping them in scf.execute_region, instead of rejecting with `not yet implemented: region with multiple blocks`. --- .../Transforms/ACCIfClauseLowering.cpp | 37 +++++++++++-------- .../OpenACC/acc-if-clause-lowering.mlir | 36 ++++++++++++++++++ 2 files changed, 58 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp index 9095d7c915fa..71df75958a13 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCIfClauseLowering.cpp @@ -59,6 +59,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h" #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" @@ -215,23 +216,29 @@ void ACCIfClauseLowering::lowerIfClauseForComputeConstruct( scf::YieldOp::create(rewriter, computeConstructOp.getLoc()); // Host execution path (false branch) - if (!computeConstructOp.getRegion().hasOneBlock()) { - accSupport->emitNYI(computeConstructOp.getLoc(), - "region with multiple blocks"); - return; + Region &hostRegion = computeConstructOp.getRegion(); + if (hostRegion.hasOneBlock()) { + // Don't need to clone original ops, just take them and legalize for host. + ifOp.getElseRegion().takeBody(hostRegion); + + // Swap acc yield for scf yield. + Block &elseBlock = ifOp.getElseRegion().front(); + elseBlock.getTerminator()->erase(); + rewriter.setInsertionPointToEnd(&elseBlock); + scf::YieldOp::create(rewriter, computeConstructOp.getLoc()); + + convertHostRegion(computeConstructOp, ifOp.getElseRegion()); + } else { + // scf.if regions must stay single-block. Wrap the original multi-block ACC + // body in scf.execute_region so it can be hosted in the else branch. + Block &elseBlock = ifOp.getElseRegion().front(); + rewriter.setInsertionPoint(elseBlock.getTerminator()); + IRMapping hostMapping; + auto hostExecuteRegion = wrapMultiBlockRegionWithSCFExecuteRegion( + hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter); + convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion()); } - // Don't need to clone original ops, just take them and legalize for host - ifOp.getElseRegion().takeBody(computeConstructOp.getRegion()); - - // Swap acc yield for scf yield - Block &elseBlock = ifOp.getElseRegion().front(); - elseBlock.getTerminator()->erase(); - rewriter.setInsertionPointToEnd(&elseBlock); - scf::YieldOp::create(rewriter, computeConstructOp.getLoc()); - - convertHostRegion(computeConstructOp, ifOp.getElseRegion()); - // The original op is now empty and can be erased eraseOps.push_back(computeConstructOp); diff --git a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir index 75f3a5cd211e..4c88df432b6c 100644 --- a/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir +++ b/mlir/test/Dialect/OpenACC/acc-if-clause-lowering.mlir @@ -37,6 +37,42 @@ func.func @test_parallel_if(%arg0: memref<10xi32>, %cond: i1) { // ----- +// Test acc.parallel if lowering when host fallback region has multiple blocks. +// CHECK-LABEL: func.func @test_parallel_if_multiblock +func.func @test_parallel_if_multiblock(%cond: i1, %n: i32) { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %counter = memref.alloca() : memref + memref.store %n, %counter[] : memref + + // CHECK-NOT: acc.parallel if + // CHECK: scf.if %{{.*}} { + // CHECK: acc.parallel { + // CHECK: } else { + // CHECK: scf.execute_region { + // CHECK: ^bb + // CHECK: cf.cond_br + // CHECK: scf.yield + // CHECK: } + // CHECK: } + acc.parallel if(%cond) { + cf.br ^bb1 + ^bb1: + %v = memref.load %counter[] : memref + %pred = arith.cmpi sgt, %v, %c0_i32 : i32 + cf.cond_br %pred, ^bb2, ^bb3 + ^bb2: + %next = arith.subi %v, %c1_i32 : i32 + memref.store %next, %counter[] : memref + cf.br ^bb1 + ^bb3: + acc.yield + } + return +} + +// ----- + // Test acc.kernels with if condition // CHECK-LABEL: func.func @test_kernels_if func.func @test_kernels_if(%arg0: memref<5xi32>, %cond: i1) {