[Offload] Add global variable address/size queries (#147972)
Add two new symbol info types for getting the bounds of a global variable. As well as a number of tests for reading/writing to it.
This commit is contained in:
parent
2c0d563a76
commit
2fdeeefacf
@ -39,7 +39,9 @@ def : Enum {
|
||||
let desc = "Supported symbol info.";
|
||||
let is_typed = 1;
|
||||
let etors = [
|
||||
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
|
||||
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
|
||||
TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
|
||||
TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
|
||||
];
|
||||
}
|
||||
|
||||
|
@ -753,9 +753,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
|
||||
void *PropValue, size_t *PropSizeRet) {
|
||||
InfoWriter Info(PropSize, PropValue, PropSizeRet);
|
||||
|
||||
auto CheckKind = [&](ol_symbol_kind_t Required) {
|
||||
if (Symbol->Kind != Required) {
|
||||
std::string ErrBuffer;
|
||||
llvm::raw_string_ostream(ErrBuffer)
|
||||
<< PropName << ": Expected a symbol of Kind " << Required
|
||||
<< " but given a symbol of Kind " << Symbol->Kind;
|
||||
return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
|
||||
}
|
||||
return Plugin::success();
|
||||
};
|
||||
|
||||
switch (PropName) {
|
||||
case OL_SYMBOL_INFO_KIND:
|
||||
return Info.write<ol_symbol_kind_t>(Symbol->Kind);
|
||||
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
|
||||
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
|
||||
return Err;
|
||||
return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
|
||||
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
|
||||
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
|
||||
return Err;
|
||||
return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
|
||||
default:
|
||||
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
|
||||
"olGetSymbolInfo enum '%i' is invalid", PropName);
|
||||
|
@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
|
||||
if (Type == "char[]") {
|
||||
OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
|
||||
} else {
|
||||
OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
|
||||
Type);
|
||||
if (Type == "void *")
|
||||
OS << formatv(TAB_2 "void * const * const tptr = (void * "
|
||||
"const * const)ptr;\n");
|
||||
else
|
||||
OS << formatv(
|
||||
TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
|
||||
// TODO: Handle other cases here
|
||||
OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
|
||||
if (Type.ends_with("*")) {
|
||||
|
@ -13,6 +13,32 @@
|
||||
using olMemcpyTest = OffloadQueueTest;
|
||||
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
|
||||
|
||||
struct olMemcpyGlobalTest : OffloadGlobalTest {
|
||||
void SetUp() override {
|
||||
RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
|
||||
ASSERT_SUCCESS(
|
||||
olGetSymbol(Program, "read", OL_SYMBOL_KIND_KERNEL, &ReadKernel));
|
||||
ASSERT_SUCCESS(
|
||||
olGetSymbol(Program, "write", OL_SYMBOL_KIND_KERNEL, &WriteKernel));
|
||||
ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
|
||||
ASSERT_SUCCESS(olGetSymbolInfo(
|
||||
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
|
||||
|
||||
LaunchArgs.Dimensions = 1;
|
||||
LaunchArgs.GroupSize = {64, 1, 1};
|
||||
LaunchArgs.NumGroups = {1, 1, 1};
|
||||
|
||||
LaunchArgs.DynSharedMemory = 0;
|
||||
}
|
||||
|
||||
ol_kernel_launch_size_args_t LaunchArgs{};
|
||||
void *Addr;
|
||||
ol_symbol_handle_t ReadKernel;
|
||||
ol_symbol_handle_t WriteKernel;
|
||||
ol_queue_handle_t Queue;
|
||||
};
|
||||
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
|
||||
|
||||
TEST_P(olMemcpyTest, SuccessHtoD) {
|
||||
constexpr size_t Size = 1024;
|
||||
void *Alloc;
|
||||
@ -105,3 +131,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
|
||||
ASSERT_SUCCESS(
|
||||
olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
|
||||
}
|
||||
|
||||
TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
|
||||
void *SourceMem;
|
||||
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
|
||||
64 * sizeof(uint32_t), &SourceMem));
|
||||
uint32_t *SourceData = (uint32_t *)SourceMem;
|
||||
for (auto I = 0; I < 64; I++)
|
||||
SourceData[I] = I;
|
||||
|
||||
void *DestMem;
|
||||
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
|
||||
64 * sizeof(uint32_t), &DestMem));
|
||||
|
||||
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
|
||||
64 * sizeof(uint32_t), nullptr));
|
||||
ASSERT_SUCCESS(olWaitQueue(Queue));
|
||||
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
|
||||
64 * sizeof(uint32_t), nullptr));
|
||||
ASSERT_SUCCESS(olWaitQueue(Queue));
|
||||
|
||||
uint32_t *DestData = (uint32_t *)DestMem;
|
||||
for (uint32_t I = 0; I < 64; I++)
|
||||
ASSERT_EQ(DestData[I], I);
|
||||
|
||||
ASSERT_SUCCESS(olMemFree(DestMem));
|
||||
ASSERT_SUCCESS(olMemFree(SourceMem));
|
||||
}
|
||||
|
||||
TEST_P(olMemcpyGlobalTest, SuccessWrite) {
|
||||
void *SourceMem;
|
||||
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
|
||||
LaunchArgs.GroupSize.x * sizeof(uint32_t),
|
||||
&SourceMem));
|
||||
uint32_t *SourceData = (uint32_t *)SourceMem;
|
||||
for (auto I = 0; I < 64; I++)
|
||||
SourceData[I] = I;
|
||||
|
||||
void *DestMem;
|
||||
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
|
||||
LaunchArgs.GroupSize.x * sizeof(uint32_t),
|
||||
&DestMem));
|
||||
struct {
|
||||
void *Mem;
|
||||
} Args{DestMem};
|
||||
|
||||
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
|
||||
64 * sizeof(uint32_t), nullptr));
|
||||
ASSERT_SUCCESS(olWaitQueue(Queue));
|
||||
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
|
||||
&LaunchArgs, nullptr));
|
||||
ASSERT_SUCCESS(olWaitQueue(Queue));
|
||||
|
||||
uint32_t *DestData = (uint32_t *)DestMem;
|
||||
for (uint32_t I = 0; I < 64; I++)
|
||||
ASSERT_EQ(DestData[I], I);
|
||||
|
||||
ASSERT_SUCCESS(olMemFree(DestMem));
|
||||
ASSERT_SUCCESS(olMemFree(SourceMem));
|
||||
}
|
||||
|
||||
TEST_P(olMemcpyGlobalTest, SuccessRead) {
|
||||
void *DestMem;
|
||||
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
|
||||
LaunchArgs.GroupSize.x * sizeof(uint32_t),
|
||||
&DestMem));
|
||||
|
||||
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
|
||||
&LaunchArgs, nullptr));
|
||||
ASSERT_SUCCESS(olWaitQueue(Queue));
|
||||
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
|
||||
64 * sizeof(uint32_t), nullptr));
|
||||
ASSERT_SUCCESS(olWaitQueue(Queue));
|
||||
|
||||
uint32_t *DestData = (uint32_t *)DestMem;
|
||||
for (uint32_t I = 0; I < 64; I++)
|
||||
ASSERT_EQ(DestData[I], I * 2);
|
||||
|
||||
ASSERT_SUCCESS(olMemFree(DestMem));
|
||||
}
|
||||
|
@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
|
||||
ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
|
||||
void *RetrievedAddr;
|
||||
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
|
||||
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
|
||||
sizeof(RetrievedAddr), &RetrievedAddr));
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
|
||||
void *RetrievedAddr = nullptr;
|
||||
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
|
||||
sizeof(RetrievedAddr), &RetrievedAddr));
|
||||
ASSERT_NE(RetrievedAddr, nullptr);
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
|
||||
size_t RetrievedSize;
|
||||
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
|
||||
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
|
||||
sizeof(RetrievedSize), &RetrievedSize));
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
|
||||
size_t RetrievedSize = 0;
|
||||
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
|
||||
sizeof(RetrievedSize), &RetrievedSize));
|
||||
ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
|
||||
ol_symbol_kind_t RetrievedKind;
|
||||
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
|
||||
|
@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
|
||||
ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
|
||||
size_t Size = 0;
|
||||
ASSERT_SUCCESS(olGetSymbolInfoSize(
|
||||
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
|
||||
ASSERT_EQ(Size, sizeof(void *));
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
|
||||
size_t Size = 0;
|
||||
ASSERT_SUCCESS(
|
||||
olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
|
||||
ASSERT_EQ(Size, sizeof(size_t));
|
||||
}
|
||||
|
||||
TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
|
||||
size_t Size = 0;
|
||||
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
|
||||
|
Loading…
x
Reference in New Issue
Block a user