[Offload] Add global variable address/size queries (#147972)

Add two new symbol info types for getting the bounds of a global
variable. As well as a number of tests for reading/writing to it.
This commit is contained in:
Ross Brunton 2025-07-11 16:12:48 +01:00 committed by GitHub
parent 2c0d563a76
commit 2fdeeefacf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 175 additions and 3 deletions

View File

@ -39,7 +39,9 @@ def : Enum {
let desc = "Supported symbol info.";
let is_typed = 1;
let etors = [
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
];
}

View File

@ -753,9 +753,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
void *PropValue, size_t *PropSizeRet) {
InfoWriter Info(PropSize, PropValue, PropSizeRet);
auto CheckKind = [&](ol_symbol_kind_t Required) {
if (Symbol->Kind != Required) {
std::string ErrBuffer;
llvm::raw_string_ostream(ErrBuffer)
<< PropName << ": Expected a symbol of Kind " << Required
<< " but given a symbol of Kind " << Symbol->Kind;
return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
}
return Plugin::success();
};
switch (PropName) {
case OL_SYMBOL_INFO_KIND:
return Info.write<ol_symbol_kind_t>(Symbol->Kind);
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
return Err;
return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
return Err;
return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"olGetSymbolInfo enum '%i' is invalid", PropName);

View File

@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
if (Type == "char[]") {
OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
} else {
OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
Type);
if (Type == "void *")
OS << formatv(TAB_2 "void * const * const tptr = (void * "
"const * const)ptr;\n");
else
OS << formatv(
TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
// TODO: Handle other cases here
OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
if (Type.ends_with("*")) {

View File

@ -13,6 +13,32 @@
using olMemcpyTest = OffloadQueueTest;
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
struct olMemcpyGlobalTest : OffloadGlobalTest {
void SetUp() override {
RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
ASSERT_SUCCESS(
olGetSymbol(Program, "read", OL_SYMBOL_KIND_KERNEL, &ReadKernel));
ASSERT_SUCCESS(
olGetSymbol(Program, "write", OL_SYMBOL_KIND_KERNEL, &WriteKernel));
ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
ASSERT_SUCCESS(olGetSymbolInfo(
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
LaunchArgs.Dimensions = 1;
LaunchArgs.GroupSize = {64, 1, 1};
LaunchArgs.NumGroups = {1, 1, 1};
LaunchArgs.DynSharedMemory = 0;
}
ol_kernel_launch_size_args_t LaunchArgs{};
void *Addr;
ol_symbol_handle_t ReadKernel;
ol_symbol_handle_t WriteKernel;
ol_queue_handle_t Queue;
};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
TEST_P(olMemcpyTest, SuccessHtoD) {
constexpr size_t Size = 1024;
void *Alloc;
@ -105,3 +131,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
ASSERT_SUCCESS(
olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
}
TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
void *SourceMem;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
64 * sizeof(uint32_t), &SourceMem));
uint32_t *SourceData = (uint32_t *)SourceMem;
for (auto I = 0; I < 64; I++)
SourceData[I] = I;
void *DestMem;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
64 * sizeof(uint32_t), &DestMem));
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
64 * sizeof(uint32_t), nullptr));
ASSERT_SUCCESS(olWaitQueue(Queue));
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
64 * sizeof(uint32_t), nullptr));
ASSERT_SUCCESS(olWaitQueue(Queue));
uint32_t *DestData = (uint32_t *)DestMem;
for (uint32_t I = 0; I < 64; I++)
ASSERT_EQ(DestData[I], I);
ASSERT_SUCCESS(olMemFree(DestMem));
ASSERT_SUCCESS(olMemFree(SourceMem));
}
TEST_P(olMemcpyGlobalTest, SuccessWrite) {
void *SourceMem;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
LaunchArgs.GroupSize.x * sizeof(uint32_t),
&SourceMem));
uint32_t *SourceData = (uint32_t *)SourceMem;
for (auto I = 0; I < 64; I++)
SourceData[I] = I;
void *DestMem;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
LaunchArgs.GroupSize.x * sizeof(uint32_t),
&DestMem));
struct {
void *Mem;
} Args{DestMem};
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
64 * sizeof(uint32_t), nullptr));
ASSERT_SUCCESS(olWaitQueue(Queue));
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
&LaunchArgs, nullptr));
ASSERT_SUCCESS(olWaitQueue(Queue));
uint32_t *DestData = (uint32_t *)DestMem;
for (uint32_t I = 0; I < 64; I++)
ASSERT_EQ(DestData[I], I);
ASSERT_SUCCESS(olMemFree(DestMem));
ASSERT_SUCCESS(olMemFree(SourceMem));
}
TEST_P(olMemcpyGlobalTest, SuccessRead) {
void *DestMem;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
LaunchArgs.GroupSize.x * sizeof(uint32_t),
&DestMem));
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
&LaunchArgs, nullptr));
ASSERT_SUCCESS(olWaitQueue(Queue));
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
64 * sizeof(uint32_t), nullptr));
ASSERT_SUCCESS(olWaitQueue(Queue));
uint32_t *DestData = (uint32_t *)DestMem;
for (uint32_t I = 0; I < 64; I++)
ASSERT_EQ(DestData[I], I * 2);
ASSERT_SUCCESS(olMemFree(DestMem));
}

View File

@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
}
TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
void *RetrievedAddr;
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
sizeof(RetrievedAddr), &RetrievedAddr));
}
TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
void *RetrievedAddr = nullptr;
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
sizeof(RetrievedAddr), &RetrievedAddr));
ASSERT_NE(RetrievedAddr, nullptr);
}
TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
size_t RetrievedSize;
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
sizeof(RetrievedSize), &RetrievedSize));
}
TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
size_t RetrievedSize = 0;
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
sizeof(RetrievedSize), &RetrievedSize));
ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
}
TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
ol_symbol_kind_t RetrievedKind;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

View File

@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
}
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
size_t Size = 0;
ASSERT_SUCCESS(olGetSymbolInfoSize(
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
ASSERT_EQ(Size, sizeof(void *));
}
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
size_t Size = 0;
ASSERT_SUCCESS(
olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
ASSERT_EQ(Size, sizeof(size_t));
}
TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
size_t Size = 0;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,