[flang][cuda] Lower set/get default stream (#181775)

This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2026-02-17 09:32:04 -08:00 committed by GitHub
parent 5addddf8f1
commit 3c32747a7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 58 additions and 2 deletions

View File

@ -25,3 +25,18 @@ TEST(DefaultStreamTest, GetAndSetTest) {
cudaStream_t outStream = RTDECL(CUFGetDefaultStream)();
EXPECT_EQ(outStream, stream);
}
TEST(DefaultStreamTest, GetAndSetArrayTest) {
using Fortran::common::TypeCategory;
cudaStream_t defaultStream = RTDECL(CUFGetDefaultStream)();
EXPECT_EQ(defaultStream, nullptr);
cudaStream_t outStream = RTDECL(CUFGetDefaultStream)();
EXPECT_EQ(outStream, nullptr);
cudaStream_t stream;
cudaStreamCreate(&stream);
EXPECT_EQ(cudaSuccess, cudaGetLastError());
RTDECL(CUFSetDefaultStream)(stream);
outStream = RTDECL(CUFGetDefaultStream)();
EXPECT_EQ(outStream, stream);
}

View File

@ -51,12 +51,16 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterBlockIndex(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue
genCUDASetDefaultStream(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue
genCUDASetDefaultStreamArray(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue
genCUDAGetDefaultStreamArg(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genCUDAGetDefaultStreamNull(mlir::Type,
llvm::ArrayRef<mlir::Value>);
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
template <const char *fctName, int extent>
fir::ExtendedValue genLDXXFunc(mlir::Type,

View File

@ -388,11 +388,21 @@ static constexpr IntrinsicHandler cudaHandlers[]{
&CI::genCUDAGetDefaultStreamArg),
{{{"devptr", asAddr}}},
/*isElemental=*/false},
{"cudagetstreamdefaultnull",
static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
&CI::genCUDAGetDefaultStreamNull),
{},
/*isElemental=*/false},
{"cudasetstreamarray",
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
&CI::genCUDASetDefaultStreamArray),
{{{"devptr", asAddr}, {"stream", asValue}}},
/*isElemental=*/false},
{"cudasetstreamdefault",
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
&CI::genCUDASetDefaultStream),
{{{"stream", asValue}}},
/*isElemental=*/false},
{"fence_proxy_async",
static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
&CI::genFenceProxyAsync),
@ -1114,6 +1124,20 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
return res;
}
// CUDASETSTREAMDEFAULT
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStream(
mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {
assert(args.size() == 1);
mlir::Value stream = fir::getBase(args[0]);
mlir::Type i64Ty = builder.getI64Type();
auto ctx = builder.getContext();
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {});
auto funcOp =
builder.createFunction(loc, RTNAME_STRING(CUFSetDefaultStream), ftype);
auto call = fir::CallOp::create(builder, loc, funcOp, {stream});
return call.getResult(0);
}
// CUDASETSTREAMARRAY
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStreamArray(
mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {
@ -1154,6 +1178,19 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genCUDAGetDefaultStreamArg(
return call.getResult(0);
}
// CUDAGETDEFAULTSTREAMNULL
mlir::Value CUDAIntrinsicLibrary::genCUDAGetDefaultStreamNull(
mlir::Type resultType, llvm::ArrayRef<mlir::Value> args) {
assert(args.size() == 0);
mlir::Type i64Ty = builder.getI64Type();
auto ctx = builder.getContext();
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {}, {i64Ty});
auto funcOp =
builder.createFunction(loc, RTNAME_STRING(CUFGetDefaultStream), ftype);
auto call = fir::CallOp::create(builder, loc, funcOp, {});
return call.getResult(0);
}
// FENCE_PROXY_ASYNC
void CUDAIntrinsicLibrary::genFenceProxyAsync(
llvm::ArrayRef<fir::ExtendedValue> args) {

View File

@ -17,13 +17,13 @@ interface cudaforgetdefaultstream
!DIR$ IGNORE_TKR (TKR) devptr
integer, device :: devptr(*)
end function
integer(kind=cuda_stream_kind) function cudastreamgetdefaultnull()
integer(kind=cuda_stream_kind) function cudagetstreamdefaultnull()
import cuda_stream_kind
end function
end interface
interface cudaforsetdefaultstream
integer function cudasetdefaultstream(stream)
integer function cudasetstreamdefault(stream)
import cuda_stream_kind
!DIR$ IGNORE_TKR (K) stream
integer(kind=cuda_stream_kind), value :: stream