diff --git a/libc/docs/gpu/rpc.rst b/libc/docs/gpu/rpc.rst index 4ac3786cfa08..c87d60640794 100644 --- a/libc/docs/gpu/rpc.rst +++ b/libc/docs/gpu/rpc.rst @@ -113,10 +113,10 @@ done. It can be omitted if asynchronous execution is desired. void rpc_host_call(void *fn, void *data, size_t size) { rpc::Client::Port port = rpc::client.open(); port.send_n(data, size); - port.send([=](rpc::Buffer *buffer) { + port.send([=](rpc::Buffer *buffer, uint32_t) { buffer->data[0] = reinterpret_cast(fn); }); - port.recv([](rpc::Buffer *) {}); + port.recv([](rpc::Buffer *, uint32_t) {}); port.close(); } @@ -131,7 +131,7 @@ call a function pointer provided by the client. In this example, the server simply runs forever in a separate thread for brevity's sake. Because the client is a GPU potentially handling several threads at once, the server needs to loop over all the active threads on the GPU. We -abstract this into the ``lane_size`` variable, which is simply the device's warp +abstract this into the ``num_lanes`` variable, which is simply the device's warp or wavefront size. The identifier is simply the threads index into the current warp or wavefront. We allocate memory to copy the struct data into, and then call the given function pointer with that copied data. The final send simply @@ -147,8 +147,8 @@ data. switch(port->get_opcode()) { case RPC_HOST_CALL: { - uint64_t sizes[LANE_SIZE]; - void *args[LANE_SIZE]; + uint64_t sizes[NUM_LANES]; + void *args[NUM_LANES]; port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; }); port->recv([&](rpc::Buffer *buffer, uint32_t id) { reinterpret_cast(buffer->data[0])(args[id]); @@ -162,8 +162,48 @@ data. port->recv([](rpc::Buffer *) {}); break; } + port->close(); } +Function Dispatch +----------------- + +There are cases where the client wants to simply execute a function as-is on the +server. A helper function is provided to make this case almost automatic. By +default, all memory is assumed to live privately on the client. 
Pointer +arguments will be copied between the client and server for correctness. Pointers +to void will all be treated as opaque pointers and copied by-value. Constant +character pointers will be treated as C-strings and copied using their length. +Functions returning void will wait for the server to complete execution rather +than submitting asynchronously. + +.. code-block:: c++ + + double fn(int x, long y, char c, double d); + + // Client-side dispatch. + double fn(int x, long y, char c, double d) { + return rpc::dispatch(client, fn, x, y, c, d); + } + + // Server-side handling. + for(;;) { + auto port = server.try_open(index); + if (!port) + continue; + + switch(port->get_opcode()) { + case OPCODE: + rpc::invoke(fn, *port); + break; + default: + port->recv([](rpc::Buffer *) {}); + break; + } + port->close(); + } + + CUDA Server Example ------------------- diff --git a/libc/shared/rpc.h b/libc/shared/rpc.h index dac2a7949a90..89f716a0d1ac 100644 --- a/libc/shared/rpc.h +++ b/libc/shared/rpc.h @@ -318,10 +318,19 @@ public: template RPC_ATTRS void recv_n(void **dst, uint64_t *size, A &&alloc); + template RPC_ATTRS void send_n(const Ty *src); + template RPC_ATTRS void recv_n(Ty *dst); + RPC_ATTRS uint32_t get_opcode() const { return process.header[index].opcode; } RPC_ATTRS uint32_t get_index() const { return index; } + RPC_ATTRS uint64_t get_lane_mask() const { + if constexpr (T) + return process.header[index].mask; + return lane_mask; + } + RPC_ATTRS void close() { // Wait for all lanes to finish using the port. rpc::sync_lane(lane_mask); @@ -392,7 +401,7 @@ template template RPC_ATTRS void Port::send(F fill) { process.wait_for_ownership(lane_mask, index, out, in); // Apply the \p fill function to initialize the buffer and release the memory. 
- invoke_rpc(fill, lane_size, process.header[index].mask, + invoke_rpc(fill, lane_size, get_lane_mask(), process.get_packet(index, lane_size)); out = process.invert_outbox(index, out); owns_buffer = false; @@ -414,7 +423,7 @@ template template RPC_ATTRS void Port::recv(U use) { process.wait_for_ownership(lane_mask, index, out, in); // Apply the \p use function to read the memory out of the buffer. - invoke_rpc(use, lane_size, process.header[index].mask, + invoke_rpc(use, lane_size, get_lane_mask(), process.get_packet(index, lane_size)); receive = true; owns_buffer = true; @@ -509,6 +518,30 @@ RPC_ATTRS void Port::recv_n(void **dst, uint64_t *size, A &&alloc) { } } +/// Simplified version of `send_n` where the size is a known constant. +template +template +RPC_ATTRS void Port::send_n(const Ty *src) { + for (uint64_t idx = 0; idx < sizeof(Ty); idx += sizeof(Buffer::data)) { + const uint64_t bytes = rpc::min(sizeof(Ty) - idx, sizeof(Buffer::data)); + send([&](Buffer *buffer, uint32_t id) { + rpc_memcpy(buffer->data, advance(&lane_value(src, id), idx), bytes); + }); + } +} + +/// Simplified version of `recv_n` where the size is a known constant. +template +template +RPC_ATTRS void Port::recv_n(Ty *dst) { + for (uint64_t idx = 0; idx < sizeof(Ty); idx += sizeof(Buffer::data)) { + const uint64_t bytes = rpc::min(sizeof(Ty) - idx, sizeof(Buffer::data)); + recv([&](Buffer *buffer, uint32_t id) { + rpc_memcpy(advance(&lane_value(dst, id), idx), buffer->data, bytes); + }); + } +} + /// Continually attempts to open a port to use as the client. The client can /// only open a port if we find an index that is in a valid sending state. 
That /// is, there are send operations pending that haven't been serviced on this @@ -590,7 +623,6 @@ RPC_ATTRS Server::Port Server::open(uint32_t lane_size) { } } -#undef RPC_ATTRS #if !__has_builtin(__scoped_atomic_load_n) #undef __scoped_atomic_load_n #undef __scoped_atomic_store_n diff --git a/libc/shared/rpc_dispatch.h b/libc/shared/rpc_dispatch.h new file mode 100644 index 000000000000..e95f82496522 --- /dev/null +++ b/libc/shared/rpc_dispatch.h @@ -0,0 +1,258 @@ +//===-- Helper functions for client / server dispatch -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "rpc.h" +#include "rpc_util.h" + +namespace rpc { +namespace { + +// Forward declarations needed for the server, we assume these are present. +extern "C" void *malloc(uint64_t); +extern "C" void free(void *); + +// Traits to convert between a tuple and binary representation of an argument +// list. +template struct tuple_bytes { + static constexpr uint64_t SIZE = rpc::max(1ul, (0 + ... 
+ sizeof(Ts))); + using array_type = rpc::array; + + template + RPC_ATTRS static constexpr array_type pack_impl(rpc::tuple t, + rpc::index_sequence) { + array_type out{}; + uint8_t *p = out.data(); + ((rpc::rpc_memcpy(p, &rpc::get(t), sizeof(Ts)), p += sizeof(Ts)), ...); + return out; + } + + RPC_ATTRS static constexpr array_type pack(rpc::tuple t) { + return pack_impl(t, rpc::index_sequence_for{}); + } + + template + RPC_ATTRS static constexpr rpc::tuple + unpack_impl(const uint8_t *data, rpc::index_sequence) { + rpc::tuple t{}; + const uint8_t *p = data; + ((rpc::rpc_memcpy(&rpc::get(t), p, sizeof(Ts)), p += sizeof(Ts)), ...); + return t; + } + + RPC_ATTRS static constexpr rpc::tuple unpack(const array_type &a) { + return unpack_impl(a.data(), rpc::index_sequence_for{}); + } +}; +template +struct tuple_bytes> : tuple_bytes {}; + +// Client-side dispatch of pointer values. We copy the memory associated with +// the pointer to the server and receive back a server-side pointer to replace +// the client-side pointer in the argument list. +template +RPC_ATTRS constexpr void prepare_arg(rpc::Client::Port &port, Tuple &t) { + using ArgTy = rpc::tuple_element_t; + if constexpr (rpc::is_pointer_v && + !rpc::is_void_v>) { + // We assume all constant character arrays are C-strings. + uint64_t size{}; + if constexpr (rpc::is_same_v) + size = rpc::string_length(rpc::get(t)); + else + size = sizeof(rpc::remove_pointer_t); + port.send_n(rpc::get(t), size); + port.recv([&](rpc::Buffer *buffer, uint32_t) { + rpc::get(t) = *reinterpret_cast(buffer->data); + }); + } +} + +// Server-side handling of pointer arguments. We receive the memory into a +// temporary buffer and pass a pointer to this new memory back to the client. 
+template +RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port) { + using ArgTy = rpc::tuple_element_t; + if constexpr (rpc::is_pointer_v && + !rpc::is_void_v>) { + void *args[NUM_LANES]{}; + uint64_t sizes[NUM_LANES]{}; + port.recv_n(args, sizes, [](uint64_t size) { + if constexpr (rpc::is_same_v) + return malloc(size); + else + return malloc( + sizeof(rpc::remove_const_t>)); + }); + port.send([&](rpc::Buffer *buffer, uint32_t id) { + *reinterpret_cast(buffer->data) = static_cast(args[id]); + }); + } +} + +// Client-side finalization of pointer arguments. If the type is not constant we +// must copy back any potential modifications the invoked function made to that +// pointer. +template +RPC_ATTRS constexpr void finish_arg(rpc::Client::Port &port, Tuple &t) { + using ArgTy = rpc::tuple_element_t; + using MemoryTy = rpc::remove_const_t> *; + if constexpr (rpc::is_pointer_v && !rpc::is_const_v && + !rpc::is_void_v>) { + uint64_t size{}; + void *buf{}; + port.recv_n(&buf, &size, [&](uint64_t) { + return const_cast(rpc::get(t)); + }); + } +} + +// Server-side finalization of pointer arguments. We copy any potential +// modifications to the value back to the client unless it was a constant. We +// can also free the associated memory. 
+template +RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port, + Tuple (&t)[NUM_LANES]) { + using ArgTy = rpc::tuple_element_t; + if constexpr (rpc::is_pointer_v && !rpc::is_const_v && + !rpc::is_void_v>) { + const void *buffer[NUM_LANES]{}; + size_t sizes[NUM_LANES]{}; + for (uint32_t id = 0; id < NUM_LANES; ++id) { + if (port.get_lane_mask() & (uint64_t(1) << id)) { + buffer[id] = rpc::get(t[id]); + sizes[id] = sizeof(rpc::remove_pointer_t); + } + } + port.send_n(buffer, sizes); + } + + if constexpr (rpc::is_pointer_v && + !rpc::is_void_v>) { + for (uint32_t id = 0; id < NUM_LANES; ++id) { + if (port.get_lane_mask() & (uint64_t(1) << id)) + free(const_cast( + static_cast(rpc::get(t[id])))); + } + } +} + +// Iterate over the tuple list of arguments to see if we need to forward any. +// The current forwarding is somewhat inefficient as each pointer is an +// individual RPC call. +template +RPC_ATTRS constexpr void prepare_args(rpc::Client::Port &port, Tuple &t, + rpc::index_sequence) { + (prepare_arg(port, t), ...); +} +template +RPC_ATTRS constexpr void prepare_args(rpc::Server::Port &port, + rpc::index_sequence) { + (prepare_arg(port), ...); +} + +// Performs the preparation in reverse, copying back any modified values. +template +RPC_ATTRS constexpr void finish_args(rpc::Client::Port &port, Tuple &&t, + rpc::index_sequence) { + (finish_arg(port, t), ...); +} +template +RPC_ATTRS constexpr void finish_args(rpc::Server::Port &port, + Tuple (&t)[NUM_LANES], + rpc::index_sequence) { + (finish_arg(port, t), ...); +} +} // namespace + +// Dispatch a function call to the server through the RPC mechanism. Copies the +// argument list through the RPC interface. +template +RPC_ATTRS constexpr typename function_traits::return_type +dispatch(rpc::Client &client, FnTy, CallArgs... 
args) { + using Traits = function_traits; + using RetTy = typename Traits::return_type; + using TupleTy = typename Traits::arg_types; + using Bytes = tuple_bytes; + + static_assert(sizeof...(CallArgs) == Traits::ARITY, + "Argument count mismatch"); + static_assert(((rpc::is_trivially_constructible_v && + rpc::is_trivially_copyable_v) && + ...), + "Must be a trivial type"); + + auto port = client.open(); + + // Copy over any pointer arguments by walking the argument list. + TupleTy arg_tuple{rpc::forward(args)...}; + rpc::prepare_args(port, arg_tuple, rpc::make_index_sequence{}); + + // Compress the argument list to a binary stream and send it to the server. + auto bytes = Bytes::pack(arg_tuple); + port.send_n(&bytes); + + // Copy back any potentially modified pointer arguments. + rpc::finish_args(port, TupleTy{rpc::forward(args)...}, + rpc::make_index_sequence{}); + + // Copy back the final function return value. + using BufferTy = rpc::conditional_t, uint8_t, RetTy>; + BufferTy ret{}; + port.recv_n(&ret); + port.close(); + + if constexpr (!rpc::is_void_v) + return ret; +} + +// Invoke a function on the server on behalf of the client. Receives the +// arguments through the interface and forwards them to the function. +template +RPC_ATTRS constexpr void invoke(FnTy fn, rpc::Server::Port &port) { + using Traits = function_traits; + using RetTy = typename Traits::return_type; + using TupleTy = typename Traits::arg_types; + using Bytes = tuple_bytes; + + // Receive pointer data from the client and pack it in server-side memory. + rpc::prepare_args( + port, rpc::make_index_sequence{}); + + // Get the argument list from the client. + typename Bytes::array_type arg_buf[NUM_LANES]{}; + port.recv_n(arg_buf); + + // Convert the received arguments into an invocable argument list. 
+ TupleTy args[NUM_LANES]; + for (uint32_t id = 0; id < NUM_LANES; ++id) { + if (port.get_lane_mask() & (uint64_t(1) << id)) + args[id] = Bytes::unpack(arg_buf[id]); + } + + // Execute the function with the provided arguments and send back any copies + // made for pointer data. + using BufferTy = rpc::conditional_t, uint8_t, RetTy>; + BufferTy rets[NUM_LANES]{}; + for (uint32_t id = 0; id < NUM_LANES; ++id) { + if (port.get_lane_mask() & (uint64_t(1) << id)) { + if constexpr (rpc::is_void_v) + rpc::apply(fn, args[id]); + else + rets[id] = rpc::apply(fn, args[id]); + } + } + + // Send any potentially modified pointer arguments back to the client. + rpc::finish_args(port, args, + rpc::make_index_sequence{}); + + // Copy back the return value of the function if one exists. If the function + // is void we send an empty packet to force synchronous behavior. + port.send_n(rets); +} +} // namespace rpc diff --git a/libc/shared/rpc_util.h b/libc/shared/rpc_util.h index 687814b7ff2a..5e2299367845 100644 --- a/libc/shared/rpc_util.h +++ b/libc/shared/rpc_util.h @@ -42,12 +42,65 @@ template struct type_constant { static inline constexpr T value = v; }; +/// Freestanding type trait helpers. 
+template struct remove_cv : type_identity {}; +template struct remove_cv : type_identity {}; +template using remove_cv_t = typename remove_cv::type; + +template struct remove_pointer : type_identity {}; +template struct remove_pointer : type_identity {}; +template using remove_pointer_t = typename remove_pointer::type; + +template struct remove_const : type_identity {}; +template struct remove_const : type_identity {}; +template using remove_const_t = typename remove_const::type; + template struct remove_reference : type_identity {}; template struct remove_reference : type_identity {}; template struct remove_reference : type_identity {}; +template +using remove_reference_t = typename remove_reference::type; template struct is_const : type_constant {}; template struct is_const : type_constant {}; +template RPC_ATTRS constexpr bool is_const_v = is_const::value; + +template struct is_pointer : type_constant {}; +template struct is_pointer : type_constant {}; +template +struct is_pointer : type_constant {}; +template +RPC_ATTRS constexpr bool is_pointer_v = is_pointer::value; + +template +struct is_same : type_constant {}; +template struct is_same : type_constant {}; +template +RPC_ATTRS constexpr bool is_same_v = is_same::value; + +template struct is_void : type_constant {}; +template <> struct is_void : type_constant {}; +template RPC_ATTRS constexpr bool is_void_v = is_void::value; + +template +struct is_trivially_copyable + : public type_constant {}; +template +RPC_ATTRS constexpr bool is_trivially_copyable_v = + is_trivially_copyable::value; + +template +struct is_trivially_constructible + : type_constant {}; +template +RPC_ATTRS constexpr bool is_trivially_constructible_v = + is_trivially_constructible::value; + +template struct conditional : type_identity {}; +template +struct conditional : type_identity {}; +template +using conditional_t = typename conditional::type; /// Freestanding implementation of std::move. 
template @@ -55,6 +108,29 @@ RPC_ATTRS constexpr typename remove_reference::type &&move(T &&t) { return static_cast::type &&>(t); } +/// Freestanding integer sequence. +template struct integer_sequence { + template using append = integer_sequence; +}; + +namespace detail { +template struct make_integer_sequence { + using type = + typename make_integer_sequence::type::template append; +}; +template struct make_integer_sequence { + using type = integer_sequence; +}; +} // namespace detail + +template +using index_sequence = integer_sequence; +template +using make_index_sequence = + typename detail::make_integer_sequence::type; +template +using index_sequence_for = make_index_sequence; + /// Freestanding implementation of std::forward. template RPC_ATTRS constexpr T &&forward(typename remove_reference::type &value) { @@ -150,6 +226,84 @@ public: RPC_ATTRS constexpr T &&operator*() && { return move(storage.stored_value); } }; +/// Minimal array type. +template struct array { + T elems[N]; + + RPC_ATTRS constexpr T *data() { return elems; } + RPC_ATTRS constexpr const T *data() const { return elems; } + RPC_ATTRS static constexpr uint64_t size() { return N; } + + RPC_ATTRS constexpr T &operator[](uint64_t i) { return elems[i]; } + RPC_ATTRS constexpr const T &operator[](uint64_t i) const { return elems[i]; } +}; + +/// Minimal tuple type. 
+template struct tuple; +template <> struct tuple<> {}; + +template +struct tuple : tuple { + Head head; + + RPC_ATTRS constexpr tuple() = default; + + template + RPC_ATTRS constexpr tuple &operator=(const tuple &other) { + head = other.get_head(); + this->get_tail() = other.get_tail(); + return *this; + } + + RPC_ATTRS constexpr tuple(const Head &h, const Tail &...t) + : tuple(t...), head(h) {} + + RPC_ATTRS constexpr Head &get_head() { return head; } + RPC_ATTRS constexpr const Head &get_head() const { return head; } + + RPC_ATTRS constexpr tuple &get_tail() { return *this; } + RPC_ATTRS constexpr const tuple &get_tail() const { return *this; } +}; + +template struct tuple_element; +template +struct tuple_element> + : tuple_element> {}; +template +struct tuple_element<0, tuple> { + using type = remove_cv_t>; +}; +template +using tuple_element_t = typename tuple_element::type; + +template +RPC_ATTRS constexpr auto &get(tuple &t) { + if constexpr (Idx == 0) + return t.get_head(); + else + return get(t.get_tail()); +} +template +RPC_ATTRS constexpr const auto &get(const tuple &t) { + if constexpr (Idx == 0) + return t.get_head(); + else + return get(t.get_tail()); +} + +namespace detail { +template +RPC_ATTRS auto apply(F &&f, Tuple &&t, index_sequence) { + return f(get(static_cast(t))...); +} +} // namespace detail + +template +RPC_ATTRS auto apply(F &&f, tuple &t) { + return detail::apply(static_cast(f), t, + make_index_sequence{}); +} + /// Suspend the thread briefly to assist the thread scheduler during busy loops. RPC_ATTRS void sleep_briefly() { #if __has_builtin(__nvvm_reflect) @@ -263,14 +417,34 @@ template RPC_ATTRS T *advance(T *ptr, U bytes) { } /// Wrapper around the optimal memory copy implementation for the target. 
-RPC_ATTRS void rpc_memcpy(void *dst, const void *src, size_t count) { +RPC_ATTRS void rpc_memcpy(void *dst, const void *src, uint64_t count) { __builtin_memcpy(dst, src, count); } -template RPC_ATTRS constexpr const T &max(const T &a, const T &b) { +/// Minimal string length function. +RPC_ATTRS constexpr uint64_t string_length(const char *s) { + const char *end = s; + for (; *end != '\0'; ++end) + ; + return static_cast(end - s + 1); +} + +/// Helper for dealing with function types. +template struct function_traits; +template struct function_traits { + using return_type = R; + using arg_types = rpc::tuple; + static constexpr uint64_t ARITY = sizeof...(Args); +}; + +template RPC_ATTRS constexpr T max(const T &a, const U &b) { return (a < b) ? b : a; } +template RPC_ATTRS constexpr T min(const T &a, const U &b) { + return (a < b) ? a : b; +} + } // namespace rpc #endif // LLVM_LIBC_SHARED_RPC_UTIL_H diff --git a/offload/test/libc/rpc_callback.c b/offload/test/libc/rpc_callback.c deleted file mode 100644 index 223b54eddd81..000000000000 --- a/offload/test/libc/rpc_callback.c +++ /dev/null @@ -1,66 +0,0 @@ -// RUN: %libomptarget-compilexx-run-and-check-generic - -// REQUIRES: libc -// REQUIRES: gpu - -#include -#include -#include - -// CHECK: PASS - -// This should be present in-tree relative to the test directory. If someone is -// using a partial tree just pass the test. 
-#if !__has_include(<../../libc/shared/rpc.h>) -int main() { printf("PASS\n"); } -#else -#include <../../libc/shared/rpc.h> - -extern "C" void __tgt_register_rpc_callback(unsigned (*Callback)(void *, - unsigned)); -constexpr uint32_t RPC_TEST_OPCODE = 1; - -template rpc::Status handleOpcodes(rpc::Server::Port &Port) { - switch (Port.get_opcode()) { - case RPC_TEST_OPCODE: { - Port.recv( - [&](rpc::Buffer *Buffer, uint32_t) { assert(Buffer->data[0] == 42); }); - Port.send([&](rpc::Buffer *, uint32_t) {}); - break; - } - default: - return rpc::RPC_UNHANDLED_OPCODE; - break; - } - return rpc::RPC_SUCCESS; -} - -static uint32_t handleOffloadOpcodes(void *Raw, uint32_t NumLanes) { - rpc::Server::Port &Port = *reinterpret_cast(Raw); - if (NumLanes == 1) - return handleOpcodes<1>(Port); - else if (NumLanes == 32) - return handleOpcodes<32>(Port); - else if (NumLanes == 64) - return handleOpcodes<64>(Port); - else - return rpc::RPC_ERROR; -} - -[[gnu::weak]] rpc::Client client asm("__llvm_rpc_client"); -#pragma omp declare target to(client) device_type(nohost) - -void __tgt_register_rpc_callback(unsigned (*Callback)(void *, unsigned)); - -int main() { - __tgt_register_rpc_callback(&handleOffloadOpcodes); -#pragma omp target - { - rpc::Client::Port Port = client.open(); - Port.send([=](rpc::Buffer *buffer, uint32_t) { buffer->data[0] = 42; }); - Port.recv([](rpc::Buffer *, uint32_t) {}); - Port.close(); - } - printf("PASS\n"); -} -#endif diff --git a/offload/test/libc/rpc_callback.cpp b/offload/test/libc/rpc_callback.cpp new file mode 100644 index 000000000000..3246e5a34bb0 --- /dev/null +++ b/offload/test/libc/rpc_callback.cpp @@ -0,0 +1,205 @@ +// RUN: %libomptarget-compilexx-run-and-check-generic +// REQUIRES: libc +// REQUIRES: gpu + +#include +#include +#include +#include +#include + +// CHECK: PASS + +// If the RPC headers are not present, just pass the test. 
+#if !__has_include(<../../libc/shared/rpc.h>) +int main() { printf("PASS\n"); } +#else + +#include <../../libc/shared/rpc.h> +#include <../../libc/shared/rpc_dispatch.h> + +[[gnu::weak]] rpc::Client client asm("__llvm_rpc_client"); +#pragma omp declare target to(client) device_type(nohost) + +//===------------------------------------------------------------------------=== +// Opcodes. +//===------------------------------------------------------------------------=== + +constexpr uint32_t FOO_OPCODE = 1; +constexpr uint32_t VOID_OPCODE = 2; +constexpr uint32_t WRITEBACK_OPCODE = 3; +constexpr uint32_t CONST_PTR_OPCODE = 4; +constexpr uint32_t STRING_OPCODE = 5; +constexpr uint32_t EMPTY_OPCODE = 6; +constexpr uint32_t DIVERGENT_OPCODE = 7; + +//===------------------------------------------------------------------------=== +// Server-side implementations. +//===------------------------------------------------------------------------=== + +struct S { + int arr[4]; +}; + +// 1. Non-pointer arguments, non-void return. +int foo(int x, double d, char c) { + assert(x == 42); + assert(d == 0.0); + assert(c == 'c'); + return -1; +} + +// 2. Void return type. +void void_fn(int x) { assert(x == 7); } + +// 3. Write-back pointer. +void writeback_fn(int *out) { + assert(out != nullptr && *out == 42); + *out = 99; +} + +// 4. Const pointer. +int sum_const(const S *p) { + int s = 0; + for (int i = 0; i < 4; ++i) + s += p->arr[i]; + return s; +} + +// 5. const char * string. +int c_string(const char *s) { + assert(s != nullptr); + assert(strcmp(s, "hello") == 0); + return strlen(s); +} + +// 6. Empty function. +int empty() { return 42; } + +// 7. Divergent values. +void divergent(int *) {} + +//===------------------------------------------------------------------------=== +// RPC client dispatch. 
+//===------------------------------------------------------------------------=== + +#pragma omp begin declare variant match(device = {kind(gpu)}) +int foo(int x, double d, char c) { + return rpc::dispatch(client, foo, x, d, c); +} + +void void_fn(int x) { rpc::dispatch(client, void_fn, x); } + +void writeback_fn(int *out) { + rpc::dispatch(client, writeback_fn, out); +} + +int sum_const(const S *p) { + return rpc::dispatch(client, sum_const, p); +} + +int c_string(const char *s) { + return rpc::dispatch(client, c_string, s); +} + +int empty() { return rpc::dispatch(client, empty); } + +void divergent(int *p) { + rpc::dispatch(client, divergent, p); +} +#pragma omp end declare variant + +//===------------------------------------------------------------------------=== +// RPC server dispatch. +//===------------------------------------------------------------------------=== + +template +rpc::Status handleOpcodesImpl(rpc::Server::Port &Port) { + switch (Port.get_opcode()) { + case FOO_OPCODE: + rpc::invoke(foo, Port); + break; + case VOID_OPCODE: + rpc::invoke(void_fn, Port); + break; + case WRITEBACK_OPCODE: + rpc::invoke(writeback_fn, Port); + break; + case CONST_PTR_OPCODE: + rpc::invoke(sum_const, Port); + break; + case STRING_OPCODE: + rpc::invoke(c_string, Port); + break; + case EMPTY_OPCODE: + rpc::invoke(empty, Port); + break; + case DIVERGENT_OPCODE: + rpc::invoke(divergent, Port); + break; + default: + return rpc::RPC_UNHANDLED_OPCODE; + } + return rpc::RPC_SUCCESS; +} + +static uint32_t handleOpcodes(void *raw, uint32_t numLanes) { + rpc::Server::Port &Port = *reinterpret_cast(raw); + if (numLanes == 1) + return handleOpcodesImpl<1>(Port); + else if (numLanes == 32) + return handleOpcodesImpl<32>(Port); + else if (numLanes == 64) + return handleOpcodesImpl<64>(Port); + else + return rpc::RPC_ERROR; +} + +extern "C" void __tgt_register_rpc_callback(unsigned (*callback)(void *, + unsigned)); + +[[gnu::constructor]] void register_callback() { + 
__tgt_register_rpc_callback(&handleOpcodes); +} + +int main() { + +#pragma omp target +#pragma omp parallel num_threads(32) + { + // 1. Non-pointer return. + assert(foo(42, 0.0, 'c') == -1); + + // 2. Void return. + void_fn(7); + + // 3. Write-back pointer. + int value = 42; + writeback_fn(&value); + assert(value == 99); + + // 4. Const pointer. + S s{1, 2, 3, 4}; + int sum = sum_const(&s); + assert(sum == 10); + + // 5. const char * string. + const char *msg = "hello"; + int len = c_string(msg); + assert(len == 5); + + // 6. No arguments. + int ret = empty(); + assert(ret == 42); + + // 7. Divergent values. + int id = omp_get_thread_num(); + if (id % 2) + divergent(&id); + assert(id == omp_get_thread_num()); + } + + printf("PASS\n"); +} + +#endif