[flang][cuda] Fix return value for CUFSetDefaultStream (#181884)
The interface return an integer value but the entry point and lowering were missing it.
This commit is contained in:
parent
e0b3e82e98
commit
7ce0c53291
@ -156,8 +156,9 @@ int RTDECL(CUFSetAssociatedStream)(void *p, cudaStream_t stream) {
|
||||
return StatOk;
|
||||
}
|
||||
|
||||
void RTDECL(CUFSetDefaultStream)(cudaStream_t stream) {
|
||||
int RTDECL(CUFSetDefaultStream)(cudaStream_t stream) {
|
||||
defaultStream = stream;
|
||||
return StatOk;
|
||||
}
|
||||
|
||||
cudaStream_t RTDECL(CUFGetDefaultStream)() { return defaultStream; }
|
||||
|
||||
@ -22,7 +22,7 @@ extern "C" {
|
||||
void RTDECL(CUFRegisterAllocator)();
|
||||
cudaStream_t RTDECL(CUFGetAssociatedStream)(void *);
|
||||
int RTDECL(CUFSetAssociatedStream)(void *, cudaStream_t);
|
||||
void RTDECL(CUFSetDefaultStream)(cudaStream_t);
|
||||
int RTDECL(CUFSetDefaultStream)(cudaStream_t);
|
||||
cudaStream_t RTDECL(CUFGetDefaultStream)();
|
||||
}
|
||||
|
||||
|
||||
@ -1131,7 +1131,7 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStream(
|
||||
mlir::Value stream = fir::getBase(args[0]);
|
||||
mlir::Type i64Ty = builder.getI64Type();
|
||||
auto ctx = builder.getContext();
|
||||
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {});
|
||||
mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {resTy});
|
||||
auto funcOp =
|
||||
builder.createFunction(loc, RTNAME_STRING(CUFSetDefaultStream), ftype);
|
||||
auto call = fir::CallOp::create(builder, loc, funcOp, {stream});
|
||||
|
||||
@ -22,3 +22,19 @@ end subroutine
|
||||
! CHECK: %[[VOIDPTR:.*]] = fir.convert %[[ADDR]] : (!fir.heap<!fir.array<?xi32>>) -> !fir.llvm_ptr<i8>
|
||||
! CHECK: %[[STREAM:.*]] = fir.call @_FortranACUFGetAssociatedStream(%[[VOIDPTR]]) fastmath<contract> : (!fir.llvm_ptr<i8>) -> i64
|
||||
! CHECK: hlfir.assign %[[STREAM]] to %{{.*}}#0 : i64, !fir.ref<i64>
|
||||
|
||||
subroutine default_stream
|
||||
use cuda_runtime_api
|
||||
|
||||
integer(kind=cuda_stream_kind) :: strm, strm2
|
||||
integer :: istat
|
||||
istat = cudaStreamCreate(strm2)
|
||||
istat = cudaforSetDefaultStream(strm2)
|
||||
strm = cudaforGetDefaultStream()
|
||||
istat = cudaStreamSynchronize(cudaforGetDefaultStream())
|
||||
end subroutine
|
||||
|
||||
! CHECK-LABEL: func.func @_QPdefault_stream()
|
||||
! CHECK: %{{.*}} = fir.call @_FortranACUFSetDefaultStream(%{{.*}}) fastmath<contract> : (i64) -> i32
|
||||
! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64
|
||||
! CHECK: %{{.*}} = fir.call @_FortranACUFGetDefaultStream() fastmath<contract> : () -> i64
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user