From 47ea8543e26a823a0543bbdf2ff529ec432c09e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?= =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?= =?UTF-8?q?=E3=83=B3=29?= Date: Wed, 22 Oct 2025 06:48:10 -1000 Subject: [PATCH] [flang] Update target rewrite to support workgroup and private attributions (#164515) Some operations like the gpu.func have arguments that need to stay in place while rewriting the signature. This is the case for the workgroup and private attribution. Update the target rewrite pass to be aware of that when adding argument at the end of the function signature. If any trailing arguments are present, the new argument will be inserted just before them. --- flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 24 +++++++-- flang/test/Fir/CUDA/cuda-target-rewrite.mlir | 53 +++++++++++++++++++ flang/tools/fir-opt/fir-opt.cpp | 1 + 3 files changed, 74 insertions(+), 4 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index ac285b5d403d..0776346870c7 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -872,6 +872,14 @@ public: } } + // Count the number of arguments that have to stay in place at the end of + // the argument list. + unsigned trailingArgs = 0; + if constexpr (std::is_same_v) { + trailingArgs = + func.getNumWorkgroupAttributions() + func.getNumPrivateAttributions(); + } + // Convert return value(s) for (auto ty : funcTy.getResults()) llvm::TypeSwitch(ty) @@ -981,6 +989,16 @@ public: } } + // Add the argument at the end if the number of trailing arguments is 0, + // otherwise insert the argument at the appropriate index. + auto addOrInsertArgument = [&](mlir::Type ty, mlir::Location loc) { + unsigned inputIndex = func.front().getArguments().size() - trailingArgs; + auto newArg = trailingArgs == 0 + ? func.front().addArgument(ty, loc) + : func.front().insertArgument(inputIndex, ty, loc); + return newArg; + }; + if (!func.empty()) { // If the function has a body, then apply the fixups to the arguments and // return ops as required. These fixups are done in place. @@ -1117,8 +1135,7 @@ public: // original arguments. (Boxchar arguments.) auto newBufArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto boxTy = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); auto box = fir::EmboxCharOp::create(*rewriter, loc, boxTy, newBufArg, @@ -1133,8 +1150,7 @@ public: // appended after all the original arguments. auto newProcPointerArg = func.front().insertArgument(fixup.index, fixupType, loc); - auto newLenArg = - func.front().addArgument(trailingTys[fixup.second], loc); + auto newLenArg = addOrInsertArgument(trailingTys[fixup.second], loc); auto tupleType = oldArgTys[fixup.index - offset]; rewriter->setInsertionPointToStart(&func.front()); fir::FirOpBuilder builder(*rewriter, getModule()); diff --git a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir index a334934f3172..48fee10f3db9 100644 --- a/flang/test/Fir/CUDA/cuda-target-rewrite.mlir +++ b/flang/test/Fir/CUDA/cuda-target-rewrite.mlir @@ -55,3 +55,56 @@ func.func @main(%arg0: complex) { // CHECK-SAME: (%arg0: f64, %arg1: f64) kernel { // CHECK: gpu.return // CHECK: gpu.launch_func @testmod::@_QPtest blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) : i64 dynamic_shared_memory_size %{{.*}} args(%{{.*}} : f64, %{{.*}} : f64) {cuf.proc_attr = #cuf.cuda_proc} + +// ----- + +module attributes {gpu.container_module, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-linux-gnu"} { + gpu.module @testmod { + gpu.func @_QMbarPfoo(%arg0: f32, %arg1: !fir.ref>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @_QMbarPfoo( +// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref>, %[[CHAR:.*]]: !fir.ref>, %[[LENGTH:.*]]: i64) workgroup(%[[WORKGROUP:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { +// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref>, i64) -> !fir.boxchar<1> +// CHECK: memref.store %{{.*}}, %[[WORKGROUP]][%{{.*}}] : memref<1xf32, #gpu.address_space> + + gpu.func @_QMbarPfoo2(%arg0: f32, %arg1: !fir.ref>, %arg2: !fir.boxchar<1>) workgroup(%arg3 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}, %arg4 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space> + memref.store %arg0, %arg4[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @_QMbarPfoo2( +// CHECK-SAME: %{{.*}}: f32, %{{.*}}: !fir.ref>, %[[CHAR:.*]]: !fir.ref>, %[[LENGTH:.*]]: i64) workgroup(%[[WG1:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}, %[[WG2:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { +// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref>, i64) -> !fir.boxchar<1> +// CHECK: memref.store %{{.*}}, %[[WG1]][%{{.*}}] : memref<1xf32, #gpu.address_space> +// CHECK: memref.store %{{.*}}, %[[WG2]][%{{.*}}] : memref<1xf32, #gpu.address_space> + + gpu.func @_QMbarPprivate(%arg0: f32, %arg1: !fir.boxchar<1>) workgroup(%arg2 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) private(%arg3 : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space> + memref.store %arg0, %arg3[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @_QMbarPprivate( +// CHECK-SAME: %{{.*}}: f32, %[[CHAR:.*]]: !fir.ref>, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) private(%[[PRIVATE:.*]] : memref<1xf32, #gpu.address_space> {llvm.align = 16 : i32}) { +// CHECK: %{{.*}} = fir.emboxchar %[[CHAR]], %[[LENGTH]] : (!fir.ref>, i64) -> !fir.boxchar<1> +// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space> +// CHECK: memref.store %{{.*}}, %[[PRIVATE]][%{{.*}}] : memref<1xf32, #gpu.address_space> + + gpu.func @test_with_char_proc(%arg0: f32, %arg1: tuple<() -> (), i64> {fir.char_proc}) workgroup(%arg2 : memref<1xf32, #gpu.address_space>) { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg2[%c0] : memref<1xf32, #gpu.address_space> + gpu.return + } +// CHECK-LABEL: gpu.func @test_with_char_proc( +// CHECK-SAME: %{{.*}}: f32, %[[CHARPROC:.*]]: () -> () {fir.char_proc}, %[[LENGTH:.*]]: i64) workgroup(%[[WG:.*]] : memref<1xf32, #gpu.address_space>) { +// CHECK: %{{.*}} = fir.undefined tuple<() -> (), i64> +// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[CHARPROC]], [0 : index] : (tuple<() -> (), i64>, () -> ()) -> tuple<() -> (), i64> +// CHECK: %{{.*}} = fir.insert_value %{{.*}}, %[[LENGTH]], [1 : index] : (tuple<() -> (), i64>, i64) -> tuple<() -> (), i64> +// CHECK: memref.store %{{.*}}, %[[WG]][%{{.*}}] : memref<1xf32, #gpu.address_space> + } +} + diff --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp index 32b0a1dfa5c7..67d07eee1f4f 100644 --- a/flang/tools/fir-opt/fir-opt.cpp +++ b/flang/tools/fir-opt/fir-opt.cpp @@ -50,6 +50,7 @@ int main(int argc, char **argv) { #endif DialectRegistry registry; fir::support::registerDialects(registry); + registry.insert(); fir::support::addFIRExtensions(registry); return failed(MlirOptMain(argc, argv, "FIR modular optimizer driver\n", registry));