[flang][cuda] Add support for cudaStreamDestroy (#183648)

Add a specific lowering and runtime entry point for cudaStreamDestroy. Since
we keep an associated stream for some allocations, we need to reset it when
the stream is destroyed so that we don't use it anymore.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2026-02-26 16:24:29 -08:00 committed by GitHub
parent 5e6f0c45a8
commit 26b4c25b8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 96 additions and 4 deletions

View File

@ -119,6 +119,15 @@ static void eraseAllocation(int pos) {
--numDeviceAllocations;
}
// Drops the stream association of every tracked device allocation that is
// currently bound to `stream`; matching entries end up with a null stream.
// The whole scan is performed while holding `lock`.
void CUFResetStream(cudaStream_t stream) {
  CriticalSection guard{lock};
  int idx{0};
  while (idx < numDeviceAllocations) {
    auto &entry{deviceAllocations[idx++]};
    if (entry.stream != stream) {
      continue;
    }
    entry.stream = nullptr;
  }
}
extern "C" {
void RTDEF(CUFRegisterAllocator)() {

View File

@ -14,6 +14,7 @@
#include "flang-rt/runtime/lock.h"
#include "flang-rt/runtime/stat.h"
#include "flang-rt/runtime/terminator.h"
#include "flang/Runtime/CUDA/allocator.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Support/Fortran.h"
@ -23,20 +24,25 @@ static thread_local cudaStream_t defaultStream{nullptr};
extern "C" {
int RTDECL(CUFSetDefaultStream)(cudaStream_t stream) {
int RTDEF(CUFSetDefaultStream)(cudaStream_t stream) {
defaultStream = stream;
return StatOk;
}
cudaStream_t RTDECL(CUFGetDefaultStream)() { return defaultStream; }
cudaStream_t RTDEF(CUFGetDefaultStream)() { return defaultStream; }
int RTDECL(CUFStreamSynchronize)(cudaStream_t stream) {
int RTDEF(CUFStreamSynchronize)(cudaStream_t stream) {
return cudaStreamSynchronize(stream);
}
int RTDECL(CUFStreamSynchronizeNull)() {
int RTDEF(CUFStreamSynchronizeNull)() {
return cudaStreamSynchronize(RTNAME(CUFGetDefaultStream)());
}
// Runtime entry point backing cudaStreamDestroy.
//
// Before the stream is destroyed, every reference the runtime still holds to
// it is dropped:
//  - the per-allocation stream associations (via CUFResetStream), and
//  - the calling thread's default stream, if it is the stream being destroyed.
//
// Returns the status reported by cudaStreamDestroy.
int RTDEF(CUFStreamDestroy)(cudaStream_t stream) {
  CUFResetStream(stream);
  if (defaultStream == stream) {
    // Without this, CUFGetDefaultStream and CUFStreamSynchronizeNull would
    // keep handing out a destroyed handle on this thread. defaultStream is
    // thread_local, so only the calling thread's default can be cleared here.
    defaultStream = nullptr;
  }
  return cudaStreamDestroy(stream);
}
}
} // namespace Fortran::runtime::cuda

View File

@ -209,3 +209,38 @@ TEST(AllocatableAsyncTest, SetStreamTest) {
int stat2 = RTDECL(CUFSetAssociatedStream)(b->raw().base_addr, stream);
EXPECT_EQ(stat2, StatBaseNull);
}
// Checks that destroying a stream through CUFStreamDestroy also clears the
// stream association recorded for an allocation made on that stream.
TEST(AllocatableAsyncTest, DestroyStreamTest) {
  using Fortran::common::TypeCategory;
  RTNAME(CUFRegisterAllocator)();
  // REAL(4), DEVICE, ALLOCATABLE :: a(:)
  auto a{createAllocatable(TypeCategory::Real, 4)};
  a->SetAllocIdx(kDeviceAllocatorPos);
  EXPECT_EQ((int)kDeviceAllocatorPos, a->GetAllocIdx());
  EXPECT_FALSE(a->HasAddendum());
  RTNAME(AllocatableSetBounds)(*a, 0, 1, 10);

  // Abort the test if the stream cannot be created: everything below would
  // otherwise operate on an uninitialized handle.
  cudaStream_t stream;
  ASSERT_EQ(cudaSuccess, cudaStreamCreate(&stream));
  EXPECT_EQ(cudaSuccess, cudaGetLastError());

  RTNAME(AllocatableAllocate)
  (*a, /*asyncObject=*/(std::int64_t *)&stream, /*hasStat=*/false,
      /*errMsg=*/nullptr, __FILE__, __LINE__);
  EXPECT_TRUE(a->IsAllocated());
  cudaDeviceSynchronize();
  EXPECT_EQ(cudaSuccess, cudaGetLastError());

  // The allocation is associated with the stream it was allocated on.
  cudaStream_t s = RTNAME(CUFGetAssociatedStream)(a->raw().base_addr);
  EXPECT_EQ(s, stream);

  // Destroying the stream must succeed and must clear that association.
  EXPECT_EQ(static_cast<int>(cudaSuccess), RTNAME(CUFStreamDestroy)(stream));
  s = RTNAME(CUFGetAssociatedStream)(a->raw().base_addr);
  EXPECT_EQ(s, nullptr);

  RTNAME(AllocatableDeallocate)
  (*a, /*hasStat=*/false, /*errMsg=*/nullptr, __FILE__, __LINE__);
  EXPECT_FALSE(a->IsAllocated());
  cudaDeviceSynchronize();
  EXPECT_EQ(cudaSuccess, cudaGetLastError());
}

View File

@ -65,6 +65,8 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
genCUDAStreamSynchronize(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genCUDAStreamSynchronizeNull(mlir::Type,
llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genCUDAStreamDestroy(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
template <const char *fctName, int extent>
fir::ExtendedValue genLDXXFunc(mlir::Type,

View File

@ -23,6 +23,8 @@ int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t);
void RTDECL(CUFRegisterAllocator)();
}
// Clears the stream recorded for every tracked device allocation that is
// currently associated with `stream`.
void CUFResetStream(cudaStream_t stream);
void *CUFAllocPinned(std::size_t, std::int64_t *);
void CUFFreePinned(void *);

View File

@ -23,6 +23,7 @@ int RTDECL(CUFSetDefaultStream)(cudaStream_t);
cudaStream_t RTDECL(CUFGetDefaultStream)();
int RTDECL(CUFStreamSynchronize)(cudaStream_t);
int RTDECL(CUFStreamSynchronizeNull)();
// Resets the allocations associated with the given stream, then destroys the
// stream; returns the status of cudaStreamDestroy.
int RTDECL(CUFStreamDestroy)(cudaStream_t);
}
} // namespace Fortran::runtime::cuda

View File

@ -403,6 +403,11 @@ static constexpr IntrinsicHandler cudaHandlers[]{
&CI::genCUDASetDefaultStream),
{{{"stream", asValue}}},
/*isElemental=*/false},
{"cudastreamdestroy",
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
&CI::genCUDAStreamDestroy),
{{{"stream", asValue}}},
/*isElemental=*/false},
{"fence_proxy_async",
static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
&CI::genFenceProxyAsync),
@ -1161,6 +1166,20 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStreamArray(
return call.getResult(0);
}
// CUDASTREAMDESTROY
// Lowers a reference to cudaStreamDestroy into a call to the CUFStreamDestroy
// runtime entry point, declared here with the signature (i64) -> resTy.
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDAStreamDestroy(
    mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {
  assert(args.size() == 1);
  mlir::Type i64Ty = builder.getI64Type();
  // The cuda_runtime_api interface marks `stream` with IGNORE_TKR (K), so the
  // actual argument may be an integer of a kind other than cuda_stream_kind.
  // Convert it so the operand always matches the i64 parameter of the runtime
  // function; createConvert is expected to return the value unchanged when it
  // is already i64.
  mlir::Value stream =
      builder.createConvert(loc, i64Ty, fir::getBase(args[0]));
  auto ctx = builder.getContext();
  mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {resTy});
  auto funcOp =
      builder.createFunction(loc, RTNAME_STRING(CUFStreamDestroy), ftype);
  auto call = fir::CallOp::create(builder, loc, funcOp, {stream});
  return call.getResult(0);
}
// CUDASTREAMSYNCHRONIZE
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDAStreamSynchronize(
mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {

View File

@ -36,4 +36,12 @@ interface cudaforsetdefaultstream
end function
end interface
! Interface for cudaStreamDestroy: takes a stream handle by value and returns
! a default-kind integer (presumably the CUDA status code).
interface cudastreamdestroy
integer function cudastreamdestroy(stream)
import cuda_stream_kind
! Kind checking is disabled for stream, so integer actual arguments of a
! kind other than cuda_stream_kind are accepted.
!DIR$ IGNORE_TKR (K) stream
integer(kind=cuda_stream_kind), value :: stream
end function
end interface
end module cuda_runtime_api

View File

@ -39,3 +39,13 @@ end subroutine
! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64
! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64
! Creates a stream and destroys it again. The FileCheck directives that follow
! this subroutine verify that the destroy call is lowered to the
! _FortranACUFStreamDestroy runtime entry point.
subroutine stream_destroy
use cuda_runtime_api
integer(kind=cuda_stream_kind) :: strm
integer :: istat
istat = cudaStreamCreate(strm)
istat = cudaStreamDestroy(strm)
end subroutine
! CHECK-LABEL: func.func @_QPstream_destroy()
! CHECK: %{{.*}} = fir.call @_FortranACUFStreamDestroy(%{{.*}}) fastmath<contract> : (i64) -> i32