[flang][cuda] Fix return value for CUFSetDefaultStream (#181884)

The interface return an integer value but the entry point and lowering
were missing it.
This commit is contained in:
Valentin Clement (バレンタイン クレメン) 2026-02-17 13:48:29 -08:00 committed by GitHub
parent e0b3e82e98
commit 7ce0c53291
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 20 additions and 3 deletions

View File

@ -156,8 +156,9 @@ int RTDECL(CUFSetAssociatedStream)(void *p, cudaStream_t stream) {
return StatOk;
}
void RTDECL(CUFSetDefaultStream)(cudaStream_t stream) {
int RTDECL(CUFSetDefaultStream)(cudaStream_t stream) {
defaultStream = stream;
return StatOk;
}
cudaStream_t RTDECL(CUFGetDefaultStream)() { return defaultStream; }

View File

@ -22,7 +22,7 @@ extern "C" {
void RTDECL(CUFRegisterAllocator)();
cudaStream_t RTDECL(CUFGetAssociatedStream)(void *);
int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t);
void RTDECL(CUFSetDefaultStream)(cudaStream_t);
int RTDECL(CUFSetDefaultStream)(cudaStream_t);
cudaStream_t RTDECL(CUFGetDefaultStream)();
}

View File

@ -1131,7 +1131,7 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStream(
mlir::Value stream = fir::getBase(args[0]);
mlir::Type i64Ty = builder.getI64Type();
auto ctx = builder.getContext();
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {});
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {resTy});
auto funcOp =
builder.createFunction(loc, RTNAME_STRING(CUFSetDefaultStream), ftype);
auto call = fir::CallOp::create(builder, loc, funcOp, {stream});

View File

@ -22,3 +22,19 @@ end subroutine
! CHECK: %[[VOIDPTR:.*]] = fir.convert %[[ADDR]] : (!fir.heap<!fir.array<?xi32>>) -> !fir.llvm_ptr<i8>
! CHECK: %[[STREAM:.*]] = fir.call @_FortranACUFGetAssociatedStream(%[[VOIDPTR]]) fastmath<contract> : (!fir.llvm_ptr<i8>) -> i64
! CHECK: hlfir.assign %[[STREAM]] to %{{.*}}#0 : i64, !fir.ref<i64>
subroutine default_stream
use cuda_runtime_api
integer(kind=cuda_stream_kind) :: strm, strm2
integer :: istat
istat = cudaStreamCreate(strm2)
istat = cudaforSetDefaultStream(strm2)
strm = cudaforGetDefaultStream()
istat = cudaStreamSynchronize(cudaforGetDefaultStream())
end subroutine
! CHECK-LABEL: func.func @_QPdefault_stream()
! CHECK: %{{.*}} = fir.call @_FortranACUFSetDefaultStream(%{{.*}}) fastmath<contract> : (i64) -> i32
! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64
! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64