diff --git a/flang-rt/lib/cuda/allocator.cpp b/flang-rt/lib/cuda/allocator.cpp index 795b26b64172..347195926def 100644 --- a/flang-rt/lib/cuda/allocator.cpp +++ b/flang-rt/lib/cuda/allocator.cpp @@ -156,8 +156,9 @@ int RTDECL(CUFSetAssociatedStream)(void *p, cudaStream_t stream) { return StatOk; } -void RTDECL(CUFSetDefaultStream)(cudaStream_t stream) { +int RTDECL(CUFSetDefaultStream)(cudaStream_t stream) { defaultStream = stream; + return StatOk; } cudaStream_t RTDECL(CUFGetDefaultStream)() { return defaultStream; } diff --git a/flang/include/flang/Runtime/CUDA/allocator.h b/flang/include/flang/Runtime/CUDA/allocator.h index d21f7d7c421b..c45b97a6df4f 100644 --- a/flang/include/flang/Runtime/CUDA/allocator.h +++ b/flang/include/flang/Runtime/CUDA/allocator.h @@ -22,7 +22,7 @@ extern "C" { void RTDECL(CUFRegisterAllocator)(); cudaStream_t RTDECL(CUFGetAssociatedStream)(void *); int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t); -void RTDECL(CUFSetDefaultStream)(cudaStream_t); +int RTDECL(CUFSetDefaultStream)(cudaStream_t); cudaStream_t RTDECL(CUFGetDefaultStream)(); } diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp index 4986c5704808..bbc353634cd4 100644 --- a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp +++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp @@ -1131,7 +1131,7 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStream( mlir::Value stream = fir::getBase(args[0]); mlir::Type i64Ty = builder.getI64Type(); auto ctx = builder.getContext(); - mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {}); + mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {resTy}); auto funcOp = builder.createFunction(loc, RTNAME_STRING(CUFSetDefaultStream), ftype); auto call = fir::CallOp::create(builder, loc, funcOp, {stream}); diff --git a/flang/test/Lower/CUDA/cuda-default-stream.cuf b/flang/test/Lower/CUDA/cuda-default-stream.cuf index 
59c6bc6b7061..beacb409f44f 100644 --- a/flang/test/Lower/CUDA/cuda-default-stream.cuf +++ b/flang/test/Lower/CUDA/cuda-default-stream.cuf @@ -22,3 +22,19 @@ end subroutine ! CHECK: %[[VOIDPTR:.*]] = fir.convert %[[ADDR]] : (!fir.heap>) -> !fir.llvm_ptr<i8> ! CHECK: %[[STREAM:.*]] = fir.call @_FortranACUFGetAssociatedStream(%[[VOIDPTR]]) fastmath<contract> : (!fir.llvm_ptr<i8>) -> i64 ! CHECK: hlfir.assign %[[STREAM]] to %{{.*}}#0 : i64, !fir.ref<i64> + +subroutine default_stream + use cuda_runtime_api + + integer(kind=cuda_stream_kind) :: strm, strm2 + integer :: istat + istat = cudaStreamCreate(strm2) + istat = cudaforSetDefaultStream(strm2) + strm = cudaforGetDefaultStream() + istat = cudaStreamSynchronize(cudaforGetDefaultStream()) +end subroutine + +! CHECK-LABEL: func.func @_QPdefault_stream() +! CHECK: %{{.*}} = fir.call @_FortranACUFSetDefaultStream(%{{.*}}) fastmath<contract> : (i64) -> i32 +! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64 +! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64