[Flang][OpenMP] Don't generate code for unreachable target regions. (#178937)
When a target region is placed inside a constant false condition (e.g.,
`if (.false.)`), the dead code gets eliminated on the host side,
removing the `omp.target` operation entirely. However, the device-side
compilation pipeline is unaware of this elimination and attempts to
generate kernel code. Since the host never created offload metadata for
the eliminated target, the device-side kernel function lacks the
"kernel" attribute, causing `OpenMPOpt` to fail with an assertion when
it expects all outlined kernels to have this attribute. The problem can
be seen with the following code:
```fortran
program cele
implicit none
real :: V
integer :: i
if (.false.) then
!$omp target teams distribute parallel do
do i = 1, 5
V = V * 2
end do
!$omp end target teams distribute parallel do
end if
end program
```
It currently fails with the following assertion:
```
Assertion `omp::isOpenMPKernel(*Kernel) && "Expected kernel function!"' failed.
llvm/lib/Transforms/IPO/OpenMPOpt.cpp:4291
```
This PR adds a `DeleteUnreachableTargetsPass` that identifies `omp.target`
operations in unreachable blocks and erases them before device-side function filtering runs.
This commit is contained in:
parent
3f0f8349ac
commit
deedc7bfe3
@ -41,6 +41,18 @@ def MarkDeclareTargetPass
|
||||
let dependentDialects = ["mlir::omp::OpenMPDialect"];
|
||||
}
|
||||
|
||||
// Pass anchored on mlir::ModuleOp and registered under the command-line name
// "omp-delete-unreachable-targets". It declares a dependency on the OpenMP
// dialect; see the summary/description strings for the intended behavior.
def DeleteUnreachableTargetsPass
|
||||
: Pass<"omp-delete-unreachable-targets", "mlir::ModuleOp"> {
|
||||
let summary = "Deletes OpenMP target operations in unreachable code";
|
||||
let description = [{
|
||||
Identifies and removes OpenMP target operations that reside in unreachable
|
||||
code (e.g., inside if(.false.) blocks). This ensures consistency between
|
||||
host and device compilation by preventing unreachable targets from being
|
||||
processed on the device side.
|
||||
}];
|
||||
let dependentDialects = ["mlir::omp::OpenMPDialect"];
|
||||
}
|
||||
|
||||
def FunctionFilteringPass : Pass<"omp-function-filtering"> {
|
||||
let summary = "Filters out functions intended for the host when compiling "
|
||||
"for the target device.";
|
||||
|
||||
@ -8,6 +8,7 @@ add_flang_library(FlangOpenMPTransforms
|
||||
MapsForPrivatizedSymbols.cpp
|
||||
MapInfoFinalization.cpp
|
||||
MarkDeclareTarget.cpp
|
||||
DeleteUnreachableTargets.cpp
|
||||
LowerWorkdistribute.cpp
|
||||
LowerWorkshare.cpp
|
||||
LowerNontemporal.cpp
|
||||
|
||||
79
flang/lib/Optimizer/OpenMP/DeleteUnreachableTargets.cpp
Normal file
79
flang/lib/Optimizer/OpenMP/DeleteUnreachableTargets.cpp
Normal file
@ -0,0 +1,79 @@
|
||||
//===- DeleteUnreachableTargets.cpp --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This pass removes OpenMP target operations that are in unreachable code.
|
||||
// This ensures host and device compilation have consistent target regions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "flang/Optimizer/Dialect/FIRDialect.h"
|
||||
#include "flang/Optimizer/Dialect/FIROps.h"
|
||||
#include "flang/Optimizer/OpenMP/Passes.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/Utils.h"
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace flangomp {
|
||||
#define GEN_PASS_DEF_DELETEUNREACHABLETARGETSPASS
|
||||
#include "flang/Optimizer/OpenMP/Passes.h.inc"
|
||||
} // namespace flangomp
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Check if an operation is unreachable using DeadCodeAnalysis.
|
||||
static bool isOperationUnreachable(Operation *op, DataFlowSolver &solver) {
  // An operation that is not attached to any block gives us nothing to query;
  // report it as reachable.
  Block *parentBlock = op->getBlock();
  if (parentBlock == nullptr)
    return false;

  // Ask the solver for the liveness state recorded at the start of the block.
  const auto *liveness = solver.lookupState<dataflow::Executable>(
      solver.getProgramPointBefore(parentBlock));

  // A missing state is not treated as proof of deadness: only a state that
  // exists and is explicitly not live marks the operation as unreachable.
  if (liveness == nullptr)
    return false;
  return !liveness->isLive();
}
|
||||
|
||||
class DeleteUnreachableTargetsPass
|
||||
: public flangomp::impl::DeleteUnreachableTargetsPassBase<
|
||||
DeleteUnreachableTargetsPass> {
|
||||
public:
|
||||
DeleteUnreachableTargetsPass() = default;
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
DataFlowSolver solver;
|
||||
dataflow::loadBaselineAnalyses(solver);
|
||||
|
||||
if (failed(solver.initializeAndRun(module))) {
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect unreachable target operations
|
||||
SmallVector<omp::TargetOp> unreachableTargets;
|
||||
module.walk([&](omp::TargetOp targetOp) {
|
||||
if (isOperationUnreachable(targetOp.getOperation(), solver))
|
||||
unreachableTargets.push_back(targetOp);
|
||||
});
|
||||
|
||||
// Delete unreachable target operations
|
||||
for (omp::TargetOp targetOp : unreachableTargets)
|
||||
targetOp->erase();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@ -122,8 +122,9 @@ public:
|
||||
// offloading can be supported.
|
||||
bool hasTargetRegion =
|
||||
funcOp
|
||||
->walk<WalkOrder::PreOrder>(
|
||||
[&](omp::TargetOp) { return WalkResult::interrupt(); })
|
||||
->walk<WalkOrder::PreOrder>([&](omp::TargetOp targetOp) {
|
||||
return WalkResult::interrupt();
|
||||
})
|
||||
.wasInterrupted();
|
||||
|
||||
omp::DeclareTargetDeviceType declareType =
|
||||
|
||||
@ -348,6 +348,11 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm,
|
||||
pm.addPass(flangomp::createAutomapToTargetDataPass());
|
||||
pm.addPass(flangomp::createMapInfoFinalizationPass());
|
||||
pm.addPass(flangomp::createMarkDeclareTargetPass());
|
||||
|
||||
// Delete unreachable target operations before FunctionFilteringPass
|
||||
// extracts them.
|
||||
pm.addPass(flangomp::createDeleteUnreachableTargetsPass());
|
||||
|
||||
pm.addPass(flangomp::createGenericLoopConversionPass());
|
||||
if (opts.isTargetDevice)
|
||||
pm.addPass(flangomp::createFunctionFilteringPass());
|
||||
|
||||
88
flang/test/Lower/OpenMP/target-dead-code.f90
Normal file
88
flang/test/Lower/OpenMP/target-dead-code.f90
Normal file
@ -0,0 +1,88 @@
|
||||
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s --check-prefix=FIR
|
||||
|
||||
! Test that OpenMP target regions in dead code are deleted
|
||||
|
||||
! Test 1: if (.false.) with target - target should be deleted
|
||||
! FIR-LABEL: func.func @_QPtest_dead_simple
|
||||
! FIR: %[[FALSE:.*]] = arith.constant false
|
||||
! FIR: fir.if %[[FALSE]] {
|
||||
! FIR-NOT: omp.target
|
||||
subroutine test_dead_simple()
|
||||
real :: v
|
||||
if (.false.) then
|
||||
!$omp target map(tofrom:v)
|
||||
v = 1.0
|
||||
!$omp end target
|
||||
end if
|
||||
end subroutine
|
||||
|
||||
! Test 2: Live target - should remain
|
||||
! FIR-LABEL: func.func @_QPtest_live_simple
|
||||
! FIR: omp.target
|
||||
subroutine test_live_simple()
|
||||
real :: v
|
||||
!$omp target map(tofrom:v)
|
||||
v = 2.0
|
||||
!$omp end target
|
||||
end subroutine
|
||||
|
||||
! Test 3: Mixed dead and live
! FIR-LABEL: func.func @_QPtest_mixed
subroutine test_mixed()
  real :: v
  ! Dead - should be deleted
  ! FIR: fir.if %{{.*}} {
  if (.false.) then
    !$omp target map(tofrom:v)
    v = 3.0
    !$omp end target
  end if
  ! FIR-NOT: omp.target
  ! Live - should remain (expect exactly 1 omp.target in function)
  !$omp target map(tofrom:v)
  ! FIR: omp.target
  v = 4.0
  !$omp end target
  ! The negative check placed before the live target only covers the output up
  ! to the first target operation that follows it, so on its own it cannot
  ! notice a dead target that survived. The trailing negative check below makes
  ! sure no second target operation is left in this function.
  ! FIR-NOT: omp.target
end subroutine
|
||||
|
||||
! Test 4: Nested - outer false, target should be deleted
|
||||
! FIR-LABEL: func.func @_QPtest_nested_outer_false
|
||||
subroutine test_nested_outer_false()
|
||||
real :: v
|
||||
! FIR: fir.if %{{.*}} {
|
||||
if (.false.) then
|
||||
if (.true.) then
|
||||
!$omp target map(tofrom:v)
|
||||
v = 5.0
|
||||
!$omp end target
|
||||
end if
|
||||
end if
|
||||
! FIR-NOT: omp.target
|
||||
end subroutine
|
||||
|
||||
! Test 5: Parameter constant - target should be deleted
|
||||
! FIR-LABEL: func.func @_QPtest_parameter
|
||||
subroutine test_parameter()
|
||||
real :: v
|
||||
logical, parameter :: DEAD = .false.
|
||||
! FIR: fir.if %{{.*}} {
|
||||
if (DEAD) then
|
||||
!$omp target map(tofrom:v)
|
||||
v = 6.0
|
||||
!$omp end target
|
||||
end if
|
||||
! FIR-NOT: omp.target
|
||||
end subroutine
|
||||
|
||||
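! Test 6: Target inside an internal procedure that is never called - no omp.target is expected to remain
! FIR-LABEL: func.func @_QPtest_outer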
|
||||
subroutine test_outer
|
||||
implicit none
|
||||
contains
|
||||
subroutine unused_sub()
|
||||
real :: v
|
||||
!$omp target map(tofrom: v)
|
||||
v = 5.0
|
||||
!$omp end target
|
||||
end subroutine
|
||||
! FIR-NOT: omp.target
|
||||
end subroutine
|
||||
322
flang/test/Transforms/OpenMP/delete-unreachable-targets.mlir
Normal file
322
flang/test/Transforms/OpenMP/delete-unreachable-targets.mlir
Normal file
@ -0,0 +1,322 @@
|
||||
// RUN: fir-opt --omp-delete-unreachable-targets %s | FileCheck %s
|
||||
|
||||
// This test verifies that OpenMP target operations in unreachable code are
|
||||
// deleted.
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @test_if_false_simple
|
||||
func.func @test_if_false_simple() {
|
||||
%false = arith.constant false
|
||||
// The target in the dead branch should be removed
|
||||
// CHECK: fir.if %false {
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: }
|
||||
fir.if %false {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_if_true_simple
|
||||
func.func @test_if_true_simple() {
|
||||
%true = arith.constant true
|
||||
// The target should remain since the branch is reachable
|
||||
// CHECK: omp.target
|
||||
fir.if %true {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_nested_outer_false
|
||||
func.func @test_nested_outer_false() {
|
||||
%false = arith.constant false
|
||||
%true = arith.constant true
|
||||
// Outer false makes the whole nested structure unreachable
|
||||
// CHECK: fir.if %false {
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: }
|
||||
fir.if %false {
|
||||
fir.if %true {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_nested_inner_false
|
||||
func.func @test_nested_inner_false() {
|
||||
%false = arith.constant false
|
||||
%true = arith.constant true
|
||||
// Outer true, inner false - target should be removed
|
||||
// CHECK: fir.if %true {
|
||||
// CHECK: fir.if %false {
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: }
|
||||
fir.if %true {
|
||||
fir.if %false {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_nested_both_true
|
||||
func.func @test_nested_both_true() {
|
||||
%true1 = arith.constant true
|
||||
%true2 = arith.constant true
|
||||
// CHECK: omp.target
|
||||
fir.if %true1 {
|
||||
fir.if %true2 {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_mixed_targets
|
||||
func.func @test_mixed_targets() {
|
||||
%false = arith.constant false
|
||||
%true = arith.constant true
|
||||
|
||||
// Live target - should remain (expect 2 targets total in output)
|
||||
// CHECK: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
|
||||
// Another live target in if (true) - should remain
|
||||
// CHECK: omp.target
|
||||
fir.if %true {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
|
||||
// Dead target - will be removed
|
||||
// CHECK-NOT: omp.target
|
||||
fir.if %false {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_multiple_dead_targets
|
||||
func.func @test_multiple_dead_targets() {
|
||||
%false = arith.constant false
|
||||
|
||||
// All targets inside dead branch should be removed
|
||||
// CHECK-NOT: omp.target
|
||||
fir.if %false {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_if_else_false
|
||||
func.func @test_if_else_false() {
|
||||
%false = arith.constant false
|
||||
|
||||
// CHECK: fir.if %false {
|
||||
fir.if %false {
|
||||
// Then branch is unreachable, target should be deleted
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
} else {
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: } else {
|
||||
// Else branch is reachable, target should remain
|
||||
// CHECK: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_runtime_condition
|
||||
func.func @test_runtime_condition(%arg0: i1) {
|
||||
// Runtime condition - cannot be optimized, should remain unchanged
|
||||
// CHECK: fir.if %arg0 {
|
||||
fir.if %arg0 {
|
||||
// CHECK: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test that targets nested in structured control flow within unreachable blocks
|
||||
// are correctly identified as unreachable
|
||||
// CHECK-LABEL: func.func @test_nested_in_unreachable_block
|
||||
func.func @test_nested_in_unreachable_block() {
|
||||
cf.br ^bb2
|
||||
^bb1:
|
||||
// This entire block is unreachable
|
||||
// Even though the fir.if condition is true, the whole block is dead
|
||||
%true = arith.constant true
|
||||
// CHECK: ^bb1:
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: cf.br ^bb2
|
||||
fir.if %true {
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
}
|
||||
cf.br ^bb2
|
||||
^bb2:
|
||||
// CHECK: ^bb2:
|
||||
// CHECK-NEXT: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_unreachable_block_after_branch
|
||||
func.func @test_unreachable_block_after_branch() {
|
||||
cf.br ^bb2
|
||||
^bb1:
|
||||
// This block is unreachable - no predecessor branches to it
|
||||
// CHECK: ^bb1:
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: cf.br ^bb2
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
cf.br ^bb2
|
||||
^bb2:
|
||||
// This block is reachable
|
||||
// CHECK: ^bb2:
|
||||
// CHECK-NEXT: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_multiple_unreachable_blocks
|
||||
func.func @test_multiple_unreachable_blocks() {
|
||||
cf.br ^bb3
|
||||
^bb1:
|
||||
// Unreachable block - no predecessor branches to it
|
||||
// CHECK: ^bb1:
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: cf.br ^bb2
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
cf.br ^bb2
|
||||
^bb2:
|
||||
// Also unreachable - only reachable from ^bb1 which is itself unreachable
|
||||
// CHECK: ^bb2:
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: return
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
^bb3:
|
||||
// Reachable from entry
|
||||
// CHECK: ^bb3:
|
||||
// CHECK-NEXT: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_both_branches_reachable
|
||||
func.func @test_both_branches_reachable(%arg0: i1) {
|
||||
cf.cond_br %arg0, ^bb1, ^bb2
|
||||
^bb1:
|
||||
// CHECK: ^bb1:
|
||||
// CHECK-NEXT: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
cf.br ^bb3
|
||||
^bb2:
|
||||
// CHECK: ^bb2:
|
||||
// CHECK-NEXT: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
cf.br ^bb3
|
||||
^bb3:
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_disconnected_block
|
||||
func.func @test_disconnected_block() {
|
||||
// Entry goes directly to exit
|
||||
cf.br ^bb2
|
||||
^bb1:
|
||||
// This block is completely disconnected - no way to reach it
|
||||
// CHECK: ^bb1:
|
||||
// CHECK-NOT: omp.target
|
||||
// CHECK: cf.br ^bb2
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
cf.br ^bb2
|
||||
^bb2:
|
||||
// Reachable from entry
|
||||
// CHECK: ^bb2:
|
||||
// CHECK-NEXT: omp.target
|
||||
omp.target {
|
||||
omp.terminator
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -6443,6 +6443,7 @@ static LogicalResult
|
||||
convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
|
||||
LLVM::ModuleTranslation &moduleTranslation) {
|
||||
auto targetOp = cast<omp::TargetOp>(opInst);
|
||||
|
||||
// The current debug location already has the DISubprogram for the outlined
|
||||
// function that will be created for the target op. We save it here so that
|
||||
// we can set it on the outlined function.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user