diff --git a/flang-rt/lib/cuda/registration.cpp b/flang-rt/lib/cuda/registration.cpp index 60b0e491b6ff..93ea36d8a6a7 100644 --- a/flang-rt/lib/cuda/registration.cpp +++ b/flang-rt/lib/cuda/registration.cpp @@ -24,6 +24,9 @@ extern void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun, extern void __cudaRegisterVar(void **fatCubinHandle, char *hostVar, const char *deviceAddress, const char *deviceName, int ext, size_t size, int constant, int global); +extern void __cudaRegisterManagedVar(void **fatCubinHandle, + void **hostVarPtrAddress, char *deviceAddress, const char *deviceName, + int ext, size_t size, int constant, int global); void *RTDECL(CUFRegisterModule)(void *data) { void **fatHandle{__cudaRegisterFatBinary(data)}; @@ -42,6 +45,11 @@ void RTDEF(CUFRegisterVariable)( __cudaRegisterVar(module, varSym, varName, varName, 0, size, 0, 0); } +void RTDEF(CUFRegisterManagedVariable)( + void **module, void **varSym, const char *varName, int64_t size) { + __cudaRegisterManagedVar(module, varSym, varName, varName, 0, size, 0, 0); +} + } // extern "C" } // namespace Fortran::runtime::cuda diff --git a/flang/include/flang/Runtime/CUDA/registration.h b/flang/include/flang/Runtime/CUDA/registration.h index 5237069a4c73..b322bb9362f2 100644 --- a/flang/include/flang/Runtime/CUDA/registration.h +++ b/flang/include/flang/Runtime/CUDA/registration.h @@ -28,6 +28,10 @@ void RTDECL(CUFRegisterFunction)( void RTDECL(CUFRegisterVariable)( void **module, char *varSym, const char *varName, int64_t size); +/// Register a managed variable. +void RTDECL(CUFRegisterManagedVariable)( + void **module, void **varSym, const char *varName, int64_t size); + } // extern "C" } // namespace Fortran::runtime::cuda