[flang][cuda] Lower set/get default stream for arrays (#181432)

This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2026-02-13 15:44:38 -08:00 committed by GitHub
parent 49fa2a4d24
commit c4170461d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 127 additions and 10 deletions

View File

@ -141,11 +141,9 @@ cudaStream_t RTDECL(CUFGetAssociatedStream)(void *p) {
return nullptr;
}
int RTDECL(CUFSetAssociatedStream)(void *p, cudaStream_t stream, bool hasStat,
const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
Terminator terminator{sourceFile, sourceLine};
int RTDECL(CUFSetAssociatedStream)(void *p, cudaStream_t stream) {
if (p == nullptr) {
return ReturnError(terminator, StatBaseNull, errMsg, hasStat);
return StatBaseNull;
}
int pos = findAllocation(p);
if (pos >= 0) {

View File

@ -205,7 +205,6 @@ TEST(AllocatableAsyncTest, SetStreamTest) {
// REAL(4), DEVICE, ALLOCATABLE :: b(:) - unallocated, base_addr is null
auto b{createAllocatable(TypeCategory::Real, 4)};
int stat2 = RTDECL(CUFSetAssociatedStream)(
b->raw().base_addr, stream, true, nullptr, __FILE__, __LINE__);
int stat2 = RTDECL(CUFSetAssociatedStream)(b->raw().base_addr, stream);
EXPECT_EQ(stat2, StatBaseNull);
}

View File

@ -51,6 +51,12 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterBlockIndex(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue
genCUDASetDefaultStreamArray(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue
genCUDAGetDefaultStreamArg(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
template <const char *fctName, int extent>
fir::ExtendedValue genLDXXFunc(mlir::Type,

View File

@ -21,9 +21,7 @@ extern "C" {
void RTDECL(CUFRegisterAllocator)();
cudaStream_t RTDECL(CUFGetAssociatedStream)(void *);
int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t, bool hasStat = false,
const Descriptor *errMsg = nullptr, const char *sourceFile = nullptr,
int sourceLine = 0);
int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t);
}
void *CUFAllocPinned(std::size_t, std::int64_t *);

View File

@ -19,6 +19,7 @@
#include "flang/Optimizer/Builder/MutableBox.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@ -382,6 +383,16 @@ static constexpr IntrinsicHandler cudaHandlers[]{
&CI::genClusterDimBlocks),
{},
/*isElemental=*/false},
{"cudagetstreamdefaultarg",
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
&CI::genCUDAGetDefaultStreamArg),
{{{"devptr", asAddr}}},
/*isElemental=*/false},
{"cudasetstreamarray",
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
&CI::genCUDASetDefaultStreamArray),
{{{"devptr", asAddr}, {"stream", asValue}}},
/*isElemental=*/false},
{"fence_proxy_async",
static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
&CI::genFenceProxyAsync),
@ -1103,6 +1114,46 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
return res;
}
// CUDASETSTREAMARRAY
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStreamArray(
mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 2);
mlir::Value arg = fir::getBase(args[0]);
mlir::Value stream = fir::getBase(args[1]);
if (mlir::isa<fir::BaseBoxType>(arg.getType()))
arg = fir::BoxAddrOp::create(builder, loc, arg);
mlir::Type i64Ty = builder.getI64Type();
mlir::Type i32Ty = builder.getI32Type();
auto ctx = builder.getContext();
mlir::Type voidPtrTy =
fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
mlir::FunctionType ftype =
mlir::FunctionType::get(ctx, {voidPtrTy, i64Ty}, {i32Ty});
mlir::Value voidPtr = builder.createConvert(loc, voidPtrTy, arg);
auto funcOp =
builder.createFunction(loc, RTNAME_STRING(CUFSetAssociatedStream), ftype);
auto call = fir::CallOp::create(builder, loc, funcOp, {voidPtr, stream});
return call.getResult(0);
}
// CUDAGETDEFAULTSTREAMARG
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDAGetDefaultStreamArg(
mlir::Type resultType, llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 1);
mlir::Value devptr = fir::getBase(args[0]);
mlir::Type i64Ty = builder.getI64Type();
auto ctx = builder.getContext();
mlir::Type voidPtrTy =
fir::LLVMPointerType::get(ctx, mlir::IntegerType::get(ctx, 8));
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {voidPtrTy}, {i64Ty});
mlir::Value voidPtr = builder.createConvert(loc, voidPtrTy, devptr);
auto funcOp =
builder.createFunction(loc, RTNAME_STRING(CUFGetAssociatedStream), ftype);
auto call = fir::CallOp::create(builder, loc, funcOp, {voidPtr});
return call.getResult(0);
}
// FENCE_PROXY_ASYNC
void CUDAIntrinsicLibrary::genFenceProxyAsync(
llvm::ArrayRef<fir::ExtendedValue> args) {

View File

@ -0,0 +1,39 @@
!===-- module/cuda_runtime_api.f90 -----------------------------------------===!
!
! Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
! See https://llvm.org/LICENSE.txt for license information.
! SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
!
!===------------------------------------------------------------------------===!
module cuda_runtime_api
implicit none
integer, parameter :: cuda_stream_kind = int_ptr_kind()
interface cudaforgetdefaultstream
integer(kind=cuda_stream_kind) function cudagetstreamdefaultarg(devptr)
import cuda_stream_kind
!DIR$ IGNORE_TKR (TKR) devptr
integer, device :: devptr(*)
end function
integer(kind=cuda_stream_kind) function cudastreamgetdefaultnull()
import cuda_stream_kind
end function
end interface
interface cudaforsetdefaultstream
integer function cudasetdefaultstream(stream)
import cuda_stream_kind
!DIR$ IGNORE_TKR (K) stream
integer(kind=cuda_stream_kind), value :: stream
end function
integer function cudasetstreamarray(devptr, stream)
import cuda_stream_kind
!DIR$ IGNORE_TKR (K) stream, (TKR) devptr
integer, device :: devptr(*)
integer(kind=cuda_stream_kind), value :: stream
end function
end interface
end module cuda_runtime_api

View File

@ -0,0 +1,24 @@
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
subroutine associated_stream
use cuda_runtime_api
integer(kind=cuda_stream_kind) :: strm, strmout
integer, managed, allocatable :: v(:)
integer :: istat
istat = cudaforSetDefaultStream(v, strm)
strmout = cudaforGetDefaultStream(v)
end subroutine
! CHECK-LABEL: func.func @_QPassociated_stream()
! CHECK: %[[ADDR:.*]] = fir.box_addr %{{.*}} : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
! CHECK: %[[STREAM:.*]] = fir.load %{{.*}}#0 : !fir.ref<i64>
! CHECK: %[[VOIDPTR:.*]] = fir.convert %[[ADDR]] : (!fir.heap<!fir.array<?xi32>>) -> !fir.llvm_ptr<i8>
! CHECK: %[[STAT:.*]] = fir.call @_FortranACUFSetAssociatedStream(%[[VOIDPTR]], %[[STREAM]]) fastmath<contract> : (!fir.llvm_ptr<i8>, i64) -> i32
! CHECK: hlfir.assign %[[STAT]] to %{{.*}}#0 : i32, !fir.ref<i32>
! CHECK: %[[ADDR:.*]] = fir.box_addr %{{.*}} : (!fir.box<!fir.heap<!fir.array<?xi32>>>) -> !fir.heap<!fir.array<?xi32>>
! CHECK: %[[VOIDPTR:.*]] = fir.convert %[[ADDR]] : (!fir.heap<!fir.array<?xi32>>) -> !fir.llvm_ptr<i8>
! CHECK: %[[STREAM:.*]] = fir.call @_FortranACUFGetAssociatedStream(%[[VOIDPTR]]) fastmath<contract> : (!fir.llvm_ptr<i8>) -> i64
! CHECK: hlfir.assign %[[STREAM]] to %{{.*}}#0 : i64, !fir.ref<i64>

View File

@ -16,6 +16,7 @@ set(MODULES
"__cuda_builtins"
"__cuda_device"
"cooperative_groups"
"cuda_runtime_api"
"cudadevice"
"ieee_arithmetic"
"ieee_exceptions"
@ -64,7 +65,8 @@ if (NOT CMAKE_CROSSCOMPILING)
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__ppc_types.mod)
elseif(${filename} STREQUAL "__cuda_device" OR
${filename} STREQUAL "cudadevice" OR
${filename} STREQUAL "cooperative_groups")
${filename} STREQUAL "cooperative_groups" OR
${filename} STREQUAL "cuda_runtime_api")
set(opts -fc1 -xcuda)
if(${filename} STREQUAL "__cuda_device")
set(depends ${FLANG_INTRINSIC_MODULES_DIR}/__cuda_builtins.mod)