[libc] Add RPC helpers for dispatching functions to the host (#179085)
Summary: The RPC interface is useful for forwarding functions. This PR adds helper functions for doing a completely bare forwarding of a function from the client to the server. This is intended to facilitate heterogeneous libraries that implement host functions on the GPU (like MPI or Fortran).
This commit is contained in:
parent
3f73f839e2
commit
6d6feb7655
@ -113,10 +113,10 @@ done. It can be omitted if asynchronous execution is desired.
|
||||
void rpc_host_call(void *fn, void *data, size_t size) {
|
||||
rpc::Client::Port port = rpc::client.open<RPC_HOST_CALL>();
|
||||
port.send_n(data, size);
|
||||
port.send([=](rpc::Buffer *buffer) {
|
||||
port.send([=](rpc::Buffer *buffer, uint32_t) {
|
||||
buffer->data[0] = reinterpret_cast<uintptr_t>(fn);
|
||||
});
|
||||
port.recv([](rpc::Buffer *) {});
|
||||
port.recv([](rpc::Buffer *, uint32_t) {});
|
||||
port.close();
|
||||
}
|
||||
|
||||
@ -131,7 +131,7 @@ call a function pointer provided by the client.
|
||||
In this example, the server simply runs forever in a separate thread for
|
||||
brevity's sake. Because the client is a GPU potentially handling several threads
|
||||
at once, the server needs to loop over all the active threads on the GPU. We
|
||||
abstract this into the ``lane_size`` variable, which is simply the device's warp
|
||||
abstract this into the ``num_lanes`` variable, which is simply the device's warp
|
||||
or wavefront size. The identifier is simply the thread's index into the current
|
||||
warp or wavefront. We allocate memory to copy the struct data into, and then
|
||||
call the given function pointer with that copied data. The final send simply
|
||||
@ -147,8 +147,8 @@ data.
|
||||
|
||||
switch(port->get_opcode()) {
|
||||
case RPC_HOST_CALL: {
|
||||
uint64_t sizes[LANE_SIZE];
|
||||
void *args[LANE_SIZE];
|
||||
uint64_t sizes[NUM_LANES];
|
||||
void *args[NUM_LANES];
|
||||
port->recv_n(args, sizes, [&](uint64_t size) { return new char[size]; });
|
||||
port->recv([&](rpc::Buffer *buffer, uint32_t id) {
|
||||
reinterpret_cast<void (*)(void *)>(buffer->data[0])(args[id]);
|
||||
@ -162,8 +162,47 @@ data.
|
||||
port->recv([](rpc::Buffer *) {});
|
||||
break;
|
||||
}
|
||||
port->close();
|
||||
}
|
||||
|
||||
Function Dispatch
|
||||
-----------------
|
||||
|
||||
There are cases where the client wants to simply execute a function as-is on the
|
||||
server. A helper function is provided to make this case almost automatic. By
|
||||
default, all memory is assumed to live privately on the client. Pointer
|
||||
arguments will be copied between the client and server for correctness. Pointers
|
||||
to void will all be treated as opaque pointers and copied by-value. Constant
|
||||
character pointers will be treated as C-strings and copied using its length.
|
||||
Functions returning void will wait for the server to complete execution rather
|
||||
than submitting asynchronously.
|
||||
|
||||
.. code-block:: c++
|
||||
|
||||
double fn(int x, long y, char c, double d);
|
||||
|
||||
// Client-side dispatch.
|
||||
double fn(int x, long y, char c, double d) {
|
||||
return rpc::dispatch<OPCODE>(client, fn, x, y, c, d);
|
||||
}
|
||||
|
||||
// Server-side handling.
|
||||
for(;;) {
|
||||
auto port = server.try_open(index);
|
||||
if (!port)
|
||||
continue;
|
||||
|
||||
switch(port->get_opcode()) {
|
||||
case OPCODE:
|
||||
rpc::invoke<NUM_LANES>(fn, *port);
break;
|
||||
default:
|
||||
port->recv([](rpc::Buffer *) {});
|
||||
break;
|
||||
}
|
||||
port->close();
|
||||
}
|
||||
|
||||
|
||||
CUDA Server Example
|
||||
-------------------
|
||||
|
||||
|
||||
@ -318,10 +318,19 @@ public:
|
||||
template <typename A>
|
||||
RPC_ATTRS void recv_n(void **dst, uint64_t *size, A &&alloc);
|
||||
|
||||
template <typename Ty> RPC_ATTRS void send_n(const Ty *src);
|
||||
template <typename Ty> RPC_ATTRS void recv_n(Ty *dst);
|
||||
|
||||
RPC_ATTRS uint32_t get_opcode() const { return process.header[index].opcode; }
|
||||
|
||||
RPC_ATTRS uint32_t get_index() const { return index; }
|
||||
|
||||
RPC_ATTRS uint64_t get_lane_mask() const {
|
||||
if constexpr (T)
|
||||
return process.header[index].mask;
|
||||
return lane_mask;
|
||||
}
|
||||
|
||||
RPC_ATTRS void close() {
|
||||
// Wait for all lanes to finish using the port.
|
||||
rpc::sync_lane(lane_mask);
|
||||
@ -392,7 +401,7 @@ template <bool T> template <typename F> RPC_ATTRS void Port<T>::send(F fill) {
|
||||
process.wait_for_ownership(lane_mask, index, out, in);
|
||||
|
||||
// Apply the \p fill function to initialize the buffer and release the memory.
|
||||
invoke_rpc(fill, lane_size, process.header[index].mask,
|
||||
invoke_rpc(fill, lane_size, get_lane_mask(),
|
||||
process.get_packet(index, lane_size));
|
||||
out = process.invert_outbox(index, out);
|
||||
owns_buffer = false;
|
||||
@ -414,7 +423,7 @@ template <bool T> template <typename U> RPC_ATTRS void Port<T>::recv(U use) {
|
||||
process.wait_for_ownership(lane_mask, index, out, in);
|
||||
|
||||
// Apply the \p use function to read the memory out of the buffer.
|
||||
invoke_rpc(use, lane_size, process.header[index].mask,
|
||||
invoke_rpc(use, lane_size, get_lane_mask(),
|
||||
process.get_packet(index, lane_size));
|
||||
receive = true;
|
||||
owns_buffer = true;
|
||||
@ -509,6 +518,30 @@ RPC_ATTRS void Port<T>::recv_n(void **dst, uint64_t *size, A &&alloc) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Simplified version of `send_n` where the size is a known constant.
|
||||
template <bool T>
|
||||
template <typename Ty>
|
||||
RPC_ATTRS void Port<T>::send_n(const Ty *src) {
|
||||
for (uint64_t idx = 0; idx < sizeof(Ty); idx += sizeof(Buffer::data)) {
|
||||
const uint64_t bytes = rpc::min(sizeof(Ty) - idx, sizeof(Buffer::data));
|
||||
send([&](Buffer *buffer, uint32_t id) {
|
||||
rpc_memcpy(buffer->data, advance(&lane_value(src, id), idx), bytes);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Simplified version of `recv_n` where the size is a known constant.
|
||||
template <bool T>
|
||||
template <typename Ty>
|
||||
RPC_ATTRS void Port<T>::recv_n(Ty *dst) {
|
||||
for (uint64_t idx = 0; idx < sizeof(Ty); idx += sizeof(Buffer::data)) {
|
||||
const uint64_t bytes = rpc::min(sizeof(Ty) - idx, sizeof(Buffer::data));
|
||||
recv([&](Buffer *buffer, uint32_t id) {
|
||||
rpc_memcpy(advance(&lane_value(dst, id), idx), buffer->data, bytes);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Continually attempts to open a port to use as the client. The client can
|
||||
/// only open a port if we find an index that is in a valid sending state. That
|
||||
/// is, there are send operations pending that haven't been serviced on this
|
||||
@ -590,7 +623,6 @@ RPC_ATTRS Server::Port Server::open(uint32_t lane_size) {
|
||||
}
|
||||
}
|
||||
|
||||
#undef RPC_ATTRS
|
||||
#if !__has_builtin(__scoped_atomic_load_n)
|
||||
#undef __scoped_atomic_load_n
|
||||
#undef __scoped_atomic_store_n
|
||||
|
||||
258
libc/shared/rpc_dispatch.h
Normal file
258
libc/shared/rpc_dispatch.h
Normal file
@ -0,0 +1,258 @@
|
||||
//===-- Helper functions for client / server dispatch -----------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "rpc.h"
|
||||
#include "rpc_util.h"
|
||||
|
||||
namespace rpc {
|
||||
namespace {
|
||||
|
||||
// Forward declarations needed for the server, we assume these are present.
|
||||
extern "C" void *malloc(uint64_t);
|
||||
extern "C" void free(void *);
|
||||
|
||||
// Traits to convert between a tuple and binary representation of an argument
|
||||
// list.
|
||||
template <typename... Ts> struct tuple_bytes {
|
||||
static constexpr uint64_t SIZE = rpc::max(1ul, (0 + ... + sizeof(Ts)));
|
||||
using array_type = rpc::array<uint8_t, SIZE>;
|
||||
|
||||
template <uint64_t... Is>
|
||||
RPC_ATTRS static constexpr array_type pack_impl(rpc::tuple<Ts...> t,
|
||||
rpc::index_sequence<Is...>) {
|
||||
array_type out{};
|
||||
uint8_t *p = out.data();
|
||||
((rpc::rpc_memcpy(p, &rpc::get<Is>(t), sizeof(Ts)), p += sizeof(Ts)), ...);
|
||||
return out;
|
||||
}
|
||||
|
||||
RPC_ATTRS static constexpr array_type pack(rpc::tuple<Ts...> t) {
|
||||
return pack_impl(t, rpc::index_sequence_for<Ts...>{});
|
||||
}
|
||||
|
||||
template <uint64_t... Is>
|
||||
RPC_ATTRS static constexpr rpc::tuple<Ts...>
|
||||
unpack_impl(const uint8_t *data, rpc::index_sequence<Is...>) {
|
||||
rpc::tuple<Ts...> t{};
|
||||
const uint8_t *p = data;
|
||||
((rpc::rpc_memcpy(&rpc::get<Is>(t), p, sizeof(Ts)), p += sizeof(Ts)), ...);
|
||||
return t;
|
||||
}
|
||||
|
||||
RPC_ATTRS static constexpr rpc::tuple<Ts...> unpack(const array_type &a) {
|
||||
return unpack_impl(a.data(), rpc::index_sequence_for<Ts...>{});
|
||||
}
|
||||
};
|
||||
template <typename... Ts>
|
||||
struct tuple_bytes<rpc::tuple<Ts...>> : tuple_bytes<Ts...> {};
|
||||
|
||||
// Client-side dispatch of pointer values. We copy the memory associated with
|
||||
// the pointer to the server and receive back a server-side pointer to replace
|
||||
// the client-side pointer in the argument list.
|
||||
template <uint64_t Idx, typename Tuple>
|
||||
RPC_ATTRS constexpr void prepare_arg(rpc::Client::Port &port, Tuple &t) {
|
||||
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
|
||||
if constexpr (rpc::is_pointer_v<ArgTy> &&
|
||||
!rpc::is_void_v<rpc::remove_pointer_t<ArgTy>>) {
|
||||
// We assume all constant character arrays are C-strings.
|
||||
uint64_t size{};
|
||||
if constexpr (rpc::is_same_v<ArgTy, const char *>)
|
||||
size = rpc::string_length(rpc::get<Idx>(t));
|
||||
else
|
||||
size = sizeof(rpc::remove_pointer_t<ArgTy>);
|
||||
port.send_n(rpc::get<Idx>(t), size);
|
||||
port.recv([&](rpc::Buffer *buffer, uint32_t) {
|
||||
rpc::get<Idx>(t) = *reinterpret_cast<ArgTy *>(buffer->data);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Server-side handling of pointer arguments. We receive the memory into a
|
||||
// temporary buffer and pass a pointer to this new memory back to the client.
|
||||
template <uint32_t NUM_LANES, typename Tuple, uint64_t Idx>
|
||||
RPC_ATTRS constexpr void prepare_arg(rpc::Server::Port &port) {
|
||||
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
|
||||
if constexpr (rpc::is_pointer_v<ArgTy> &&
|
||||
!rpc::is_void_v<rpc::remove_pointer_t<ArgTy>>) {
|
||||
void *args[NUM_LANES]{};
|
||||
uint64_t sizes[NUM_LANES]{};
|
||||
port.recv_n(args, sizes, [](uint64_t size) {
|
||||
if constexpr (rpc::is_same_v<ArgTy, const char *>)
|
||||
return malloc(size);
|
||||
else
|
||||
return malloc(
|
||||
sizeof(rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>>));
|
||||
});
|
||||
port.send([&](rpc::Buffer *buffer, uint32_t id) {
|
||||
*reinterpret_cast<ArgTy *>(buffer->data) = static_cast<ArgTy>(args[id]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Client-side finalization of pointer arguments. If the type is not constant we
|
||||
// must copy back any potential modifications the invoked function made to that
|
||||
// pointer.
|
||||
template <uint64_t Idx, typename Tuple>
|
||||
RPC_ATTRS constexpr void finish_arg(rpc::Client::Port &port, Tuple &t) {
|
||||
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
|
||||
using MemoryTy = rpc::remove_const_t<rpc::remove_pointer_t<ArgTy>> *;
|
||||
if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy> &&
|
||||
!rpc::is_void_v<rpc::remove_pointer_t<ArgTy>>) {
|
||||
uint64_t size{};
|
||||
void *buf{};
|
||||
port.recv_n(&buf, &size, [&](uint64_t) {
|
||||
return const_cast<MemoryTy>(rpc::get<Idx>(t));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Server-side finalization of pointer arguments. We copy any potential
|
||||
// modifications to the value back to the client unless it was a constant. We
|
||||
// can also free the associated memory.
|
||||
template <uint32_t NUM_LANES, uint64_t Idx, typename Tuple>
|
||||
RPC_ATTRS constexpr void finish_arg(rpc::Server::Port &port,
|
||||
Tuple (&t)[NUM_LANES]) {
|
||||
using ArgTy = rpc::tuple_element_t<Idx, Tuple>;
|
||||
if constexpr (rpc::is_pointer_v<ArgTy> && !rpc::is_const_v<ArgTy> &&
|
||||
!rpc::is_void_v<rpc::remove_pointer_t<ArgTy>>) {
|
||||
const void *buffer[NUM_LANES]{};
|
||||
size_t sizes[NUM_LANES]{};
|
||||
for (uint32_t id = 0; id < NUM_LANES; ++id) {
|
||||
if (port.get_lane_mask() & (uint64_t(1) << id)) {
|
||||
buffer[id] = rpc::get<Idx>(t[id]);
|
||||
sizes[id] = sizeof(rpc::remove_pointer_t<ArgTy>);
|
||||
}
|
||||
}
|
||||
port.send_n(buffer, sizes);
|
||||
}
|
||||
|
||||
if constexpr (rpc::is_pointer_v<ArgTy> &&
|
||||
!rpc::is_void_v<rpc::remove_pointer_t<ArgTy>>) {
|
||||
for (uint32_t id = 0; id < NUM_LANES; ++id) {
|
||||
if (port.get_lane_mask() & (uint64_t(1) << id))
|
||||
free(const_cast<void *>(
|
||||
static_cast<const void *>(rpc::get<Idx>(t[id]))));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate over the tuple list of arguments to see if we need to forward any.
|
||||
// The current forwarding is somewhat inefficient as each pointer is an
|
||||
// individual RPC call.
|
||||
template <typename Tuple, uint64_t... Is>
|
||||
RPC_ATTRS constexpr void prepare_args(rpc::Client::Port &port, Tuple &t,
|
||||
rpc::index_sequence<Is...>) {
|
||||
(prepare_arg<Is>(port, t), ...);
|
||||
}
|
||||
template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
|
||||
RPC_ATTRS constexpr void prepare_args(rpc::Server::Port &port,
|
||||
rpc::index_sequence<Is...>) {
|
||||
(prepare_arg<NUM_LANES, Tuple, Is>(port), ...);
|
||||
}
|
||||
|
||||
// Performs the preparation in reverse, copying back any modified values.
|
||||
template <typename Tuple, uint64_t... Is>
|
||||
RPC_ATTRS constexpr void finish_args(rpc::Client::Port &port, Tuple &&t,
|
||||
rpc::index_sequence<Is...>) {
|
||||
(finish_arg<Is>(port, t), ...);
|
||||
}
|
||||
template <uint32_t NUM_LANES, typename Tuple, uint64_t... Is>
|
||||
RPC_ATTRS constexpr void finish_args(rpc::Server::Port &port,
|
||||
Tuple (&t)[NUM_LANES],
|
||||
rpc::index_sequence<Is...>) {
|
||||
(finish_arg<NUM_LANES, Is>(port, t), ...);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Dispatch a function call to the server through the RPC mechanism. Copies the
|
||||
// argument list through the RPC interface.
|
||||
template <uint32_t OPCODE, typename FnTy, typename... CallArgs>
|
||||
RPC_ATTRS constexpr typename function_traits<FnTy>::return_type
|
||||
dispatch(rpc::Client &client, FnTy, CallArgs... args) {
|
||||
using Traits = function_traits<FnTy>;
|
||||
using RetTy = typename Traits::return_type;
|
||||
using TupleTy = typename Traits::arg_types;
|
||||
using Bytes = tuple_bytes<CallArgs...>;
|
||||
|
||||
static_assert(sizeof...(CallArgs) == Traits::ARITY,
|
||||
"Argument count mismatch");
|
||||
static_assert(((rpc::is_trivially_constructible_v<CallArgs> &&
|
||||
rpc::is_trivially_copyable_v<CallArgs>) &&
|
||||
...),
|
||||
"Must be a trivial type");
|
||||
|
||||
auto port = client.open<OPCODE>();
|
||||
|
||||
// Copy over any pointer arguments by walking the argument list.
|
||||
TupleTy arg_tuple{rpc::forward<CallArgs>(args)...};
|
||||
rpc::prepare_args(port, arg_tuple, rpc::make_index_sequence<Traits::ARITY>{});
|
||||
|
||||
// Compress the argument list to a binary stream and send it to the server.
|
||||
auto bytes = Bytes::pack(arg_tuple);
|
||||
port.send_n(&bytes);
|
||||
|
||||
// Copy back any potentially modified pointer arguments and the return value.
|
||||
rpc::finish_args(port, TupleTy{rpc::forward<CallArgs>(args)...},
|
||||
rpc::make_index_sequence<Traits::ARITY>{});
|
||||
|
||||
// Copy back the final function return value.
|
||||
using BufferTy = rpc::conditional_t<rpc::is_void_v<RetTy>, uint8_t, RetTy>;
|
||||
BufferTy ret{};
|
||||
port.recv_n(&ret);
|
||||
port.close();
|
||||
|
||||
if constexpr (!rpc::is_void_v<RetTy>)
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Invoke a function on the server on behalf of the client. Receives the
|
||||
// arguments through the interface and forwards them to the function.
|
||||
template <uint32_t NUM_LANES, typename FnTy>
|
||||
RPC_ATTRS constexpr void invoke(FnTy fn, rpc::Server::Port &port) {
|
||||
using Traits = function_traits<FnTy>;
|
||||
using RetTy = typename Traits::return_type;
|
||||
using TupleTy = typename Traits::arg_types;
|
||||
using Bytes = tuple_bytes<TupleTy>;
|
||||
|
||||
// Receive pointer data from the host and pack it in server-side memory.
|
||||
rpc::prepare_args<NUM_LANES, TupleTy>(
|
||||
port, rpc::make_index_sequence<Traits::ARITY>{});
|
||||
|
||||
// Get the argument list from the client.
|
||||
typename Bytes::array_type arg_buf[NUM_LANES]{};
|
||||
port.recv_n(arg_buf);
|
||||
|
||||
// Convert the received arguments into an invocable argument list.
|
||||
TupleTy args[NUM_LANES];
|
||||
for (uint32_t id = 0; id < NUM_LANES; ++id) {
|
||||
if (port.get_lane_mask() & (uint64_t(1) << id))
|
||||
args[id] = Bytes::unpack(arg_buf[id]);
|
||||
}
|
||||
|
||||
// Execute the function with the provided arguments and send back any copies
|
||||
// made for pointer data.
|
||||
using BufferTy = rpc::conditional_t<rpc::is_void_v<RetTy>, uint8_t, RetTy>;
|
||||
BufferTy rets[NUM_LANES]{};
|
||||
for (uint32_t id = 0; id < NUM_LANES; ++id) {
|
||||
if (port.get_lane_mask() & (uint64_t(1) << id)) {
|
||||
if constexpr (rpc::is_void_v<RetTy>)
|
||||
rpc::apply(fn, args[id]);
|
||||
else
|
||||
rets[id] = rpc::apply(fn, args[id]);
|
||||
}
|
||||
}
|
||||
|
||||
// Send any potentially modified pointer arguments back to the client.
|
||||
rpc::finish_args<NUM_LANES>(port, args,
|
||||
rpc::make_index_sequence<Traits::ARITY>{});
|
||||
|
||||
// Copy back the return value of the function if one exists. If the function
|
||||
// is void we send an empty packet to force synchronous behavior.
|
||||
port.send_n(rets);
|
||||
}
|
||||
} // namespace rpc
|
||||
@ -42,12 +42,65 @@ template <class T, T v> struct type_constant {
|
||||
static inline constexpr T value = v;
|
||||
};
|
||||
|
||||
/// Freestanding type trait helpers.
|
||||
/// Strips top-level `const` and/or `volatile` from \p T. Qualifiers on a
/// pointee are left alone.
template <class T> struct remove_cv : type_identity<T> {};
template <class T> struct remove_cv<const T> : type_identity<T> {};
template <class T> struct remove_cv<volatile T> : type_identity<T> {};
template <class T> struct remove_cv<const volatile T> : type_identity<T> {};
template <class T> using remove_cv_t = typename remove_cv<T>::type;
|
||||
|
||||
template <class T> struct remove_pointer : type_identity<T> {};
|
||||
template <class T> struct remove_pointer<T *> : type_identity<T> {};
|
||||
template <class T> using remove_pointer_t = typename remove_pointer<T>::type;
|
||||
|
||||
template <class T> struct remove_const : type_identity<T> {};
|
||||
template <class T> struct remove_const<const T> : type_identity<T> {};
|
||||
template <class T> using remove_const_t = typename remove_const<T>::type;
|
||||
|
||||
template <class T> struct remove_reference : type_identity<T> {};
|
||||
template <class T> struct remove_reference<T &> : type_identity<T> {};
|
||||
template <class T> struct remove_reference<T &&> : type_identity<T> {};
|
||||
template <class T>
|
||||
using remove_reference_t = typename remove_reference<T>::type;
|
||||
|
||||
template <class T> struct is_const : type_constant<bool, false> {};
|
||||
template <class T> struct is_const<const T> : type_constant<bool, true> {};
|
||||
template <class T> RPC_ATTRS constexpr bool is_const_v = is_const<T>::value;
|
||||
|
||||
template <typename T> struct is_pointer : type_constant<bool, false> {};
|
||||
template <typename T> struct is_pointer<T *> : type_constant<bool, true> {};
|
||||
template <typename T>
|
||||
struct is_pointer<T *const> : type_constant<bool, true> {};
|
||||
template <typename T>
|
||||
RPC_ATTRS constexpr bool is_pointer_v = is_pointer<T>::value;
|
||||
|
||||
template <typename T, typename U>
|
||||
struct is_same : type_constant<bool, false> {};
|
||||
template <typename T> struct is_same<T, T> : type_constant<bool, true> {};
|
||||
template <typename T, typename U>
|
||||
RPC_ATTRS constexpr bool is_same_v = is_same<T, U>::value;
|
||||
|
||||
template <class T> struct is_void : type_constant<bool, false> {};
|
||||
template <> struct is_void<void> : type_constant<bool, true> {};
|
||||
template <typename T> RPC_ATTRS constexpr bool is_void_v = is_void<T>::value;
|
||||
|
||||
template <class T>
|
||||
struct is_trivially_copyable
|
||||
: public type_constant<bool, __is_trivially_copyable(T)> {};
|
||||
template <class T>
|
||||
RPC_ATTRS constexpr bool is_trivially_copyable_v =
|
||||
is_trivially_copyable<T>::value;
|
||||
|
||||
template <class T, class... Args>
|
||||
struct is_trivially_constructible
|
||||
: type_constant<bool, __is_trivially_constructible(T, Args...)> {};
|
||||
template <class T, class... Args>
|
||||
RPC_ATTRS constexpr bool is_trivially_constructible_v =
|
||||
is_trivially_constructible<T>::value;
|
||||
|
||||
template <bool B, class T, class F> struct conditional : type_identity<T> {};
|
||||
template <class T, class F>
|
||||
struct conditional<false, T, F> : type_identity<F> {};
|
||||
template <bool B, class T, class F>
|
||||
using conditional_t = typename conditional<B, T, F>::type;
|
||||
|
||||
/// Freestanding implementation of std::move.
|
||||
template <class T>
|
||||
@ -55,6 +108,29 @@ RPC_ATTRS constexpr typename remove_reference<T>::type &&move(T &&t) {
|
||||
return static_cast<typename remove_reference<T>::type &&>(t);
|
||||
}
|
||||
|
||||
/// Freestanding integer sequence.
|
||||
template <typename T, T... Ints> struct integer_sequence {
|
||||
template <T Next> using append = integer_sequence<T, Ints..., Next>;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
template <typename T, int N> struct make_integer_sequence {
|
||||
using type =
|
||||
typename make_integer_sequence<T, N - 1>::type::template append<N>;
|
||||
};
|
||||
template <typename T> struct make_integer_sequence<T, -1> {
|
||||
using type = integer_sequence<T>;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <uint64_t... Ints>
|
||||
using index_sequence = integer_sequence<uint64_t, Ints...>;
|
||||
template <int N>
|
||||
using make_index_sequence =
|
||||
typename detail::make_integer_sequence<uint64_t, N - 1>::type;
|
||||
template <typename... Ts>
|
||||
using index_sequence_for = make_index_sequence<sizeof...(Ts)>;
|
||||
|
||||
/// Freestanding implementation of std::forward.
|
||||
template <typename T>
|
||||
RPC_ATTRS constexpr T &&forward(typename remove_reference<T>::type &value) {
|
||||
@ -150,6 +226,84 @@ public:
|
||||
RPC_ATTRS constexpr T &&operator*() && { return move(storage.stored_value); }
|
||||
};
|
||||
|
||||
/// Minimal array type.
|
||||
template <typename T, uint64_t N> struct array {
|
||||
T elems[N];
|
||||
|
||||
RPC_ATTRS constexpr T *data() { return elems; }
|
||||
RPC_ATTRS constexpr const T *data() const { return elems; }
|
||||
RPC_ATTRS static constexpr uint64_t size() { return N; }
|
||||
|
||||
RPC_ATTRS constexpr T &operator[](uint64_t i) { return elems[i]; }
|
||||
RPC_ATTRS constexpr const T &operator[](uint64_t i) const { return elems[i]; }
|
||||
};
|
||||
|
||||
/// Minimal tuple type.
|
||||
template <typename... Ts> struct tuple;
|
||||
template <> struct tuple<> {};
|
||||
|
||||
template <typename Head, typename... Tail>
|
||||
struct tuple<Head, Tail...> : tuple<Tail...> {
|
||||
Head head;
|
||||
|
||||
RPC_ATTRS constexpr tuple() = default;
|
||||
|
||||
template <typename OHead, typename... OTail>
|
||||
RPC_ATTRS constexpr tuple &operator=(const tuple<OHead, OTail...> &other) {
|
||||
head = other.get_head();
|
||||
this->get_tail() = other.get_tail();
|
||||
return *this;
|
||||
}
|
||||
|
||||
RPC_ATTRS constexpr tuple(const Head &h, const Tail &...t)
|
||||
: tuple<Tail...>(t...), head(h) {}
|
||||
|
||||
RPC_ATTRS constexpr Head &get_head() { return head; }
|
||||
RPC_ATTRS constexpr const Head &get_head() const { return head; }
|
||||
|
||||
RPC_ATTRS constexpr tuple<Tail...> &get_tail() { return *this; }
|
||||
RPC_ATTRS constexpr const tuple<Tail...> &get_tail() const { return *this; }
|
||||
};
|
||||
|
||||
template <size_t Idx, typename T> struct tuple_element;
|
||||
template <size_t Idx, typename Head, typename... Tail>
|
||||
struct tuple_element<Idx, tuple<Head, Tail...>>
|
||||
: tuple_element<Idx - 1, tuple<Tail...>> {};
|
||||
template <typename Head, typename... Tail>
|
||||
struct tuple_element<0, tuple<Head, Tail...>> {
|
||||
using type = remove_cv_t<remove_reference_t<Head>>;
|
||||
};
|
||||
template <size_t Idx, typename T>
|
||||
using tuple_element_t = typename tuple_element<Idx, T>::type;
|
||||
|
||||
template <uint64_t Idx, typename Head, typename... Tail>
|
||||
RPC_ATTRS constexpr auto &get(tuple<Head, Tail...> &t) {
|
||||
if constexpr (Idx == 0)
|
||||
return t.get_head();
|
||||
else
|
||||
return get<Idx - 1>(t.get_tail());
|
||||
}
|
||||
template <uint64_t Idx, typename Head, typename... Tail>
|
||||
RPC_ATTRS constexpr const auto &get(const tuple<Head, Tail...> &t) {
|
||||
if constexpr (Idx == 0)
|
||||
return t.get_head();
|
||||
else
|
||||
return get<Idx - 1>(t.get_tail());
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <typename F, typename Tuple, uint64_t... Is>
|
||||
RPC_ATTRS auto apply(F &&f, Tuple &&t, index_sequence<Is...>) {
|
||||
return f(get<Is>(static_cast<Tuple &&>(t))...);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename... Ts>
|
||||
RPC_ATTRS auto apply(F &&f, tuple<Ts...> &t) {
|
||||
return detail::apply(static_cast<F &&>(f), t,
|
||||
make_index_sequence<sizeof...(Ts)>{});
|
||||
}
|
||||
|
||||
/// Suspend the thread briefly to assist the thread scheduler during busy loops.
|
||||
RPC_ATTRS void sleep_briefly() {
|
||||
#if __has_builtin(__nvvm_reflect)
|
||||
@ -263,14 +417,34 @@ template <typename T, typename U> RPC_ATTRS T *advance(T *ptr, U bytes) {
|
||||
}
|
||||
|
||||
/// Wrapper around the optimal memory copy implementation for the target.
|
||||
RPC_ATTRS void rpc_memcpy(void *dst, const void *src, size_t count) {
|
||||
RPC_ATTRS void rpc_memcpy(void *dst, const void *src, uint64_t count) {
|
||||
__builtin_memcpy(dst, src, count);
|
||||
}
|
||||
|
||||
template <class T> RPC_ATTRS constexpr const T &max(const T &a, const T &b) {
|
||||
/// Minimal string length function. The returned length includes the null
/// terminator.
|
||||
RPC_ATTRS constexpr uint64_t string_length(const char *s) {
|
||||
const char *end = s;
|
||||
for (; *end != '\0'; ++end)
|
||||
;
|
||||
return static_cast<uint64_t>(end - s + 1);
|
||||
}
|
||||
|
||||
/// Helper for dealing with function types.
|
||||
template <typename> struct function_traits;
|
||||
template <typename R, typename... Args> struct function_traits<R (*)(Args...)> {
|
||||
using return_type = R;
|
||||
using arg_types = rpc::tuple<Args...>;
|
||||
static constexpr uint64_t ARITY = sizeof...(Args);
|
||||
};
|
||||
|
||||
template <class T, class U> RPC_ATTRS constexpr T max(const T &a, const U &b) {
|
||||
return (a < b) ? b : a;
|
||||
}
|
||||
|
||||
template <class T, class U> RPC_ATTRS constexpr T min(const T &a, const U &b) {
|
||||
return (a < b) ? a : b;
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
|
||||
#endif // LLVM_LIBC_SHARED_RPC_UTIL_H
|
||||
|
||||
@ -1,66 +0,0 @@
|
||||
// RUN: %libomptarget-compilexx-run-and-check-generic
|
||||
|
||||
// REQUIRES: libc
|
||||
// REQUIRES: gpu
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
// CHECK: PASS
|
||||
|
||||
// This should be present in-tree relative to the test directory. If someone is
|
||||
// using a partial tree just pass the test.
|
||||
#if !__has_include(<../../libc/shared/rpc.h>)
|
||||
int main() { printf("PASS\n"); }
|
||||
#else
|
||||
#include <../../libc/shared/rpc.h>
|
||||
|
||||
extern "C" void __tgt_register_rpc_callback(unsigned (*Callback)(void *,
|
||||
unsigned));
|
||||
constexpr uint32_t RPC_TEST_OPCODE = 1;
|
||||
|
||||
template<uint32_t NumLanes> rpc::Status handleOpcodes(rpc::Server::Port &Port) {
|
||||
switch (Port.get_opcode()) {
|
||||
case RPC_TEST_OPCODE: {
|
||||
Port.recv(
|
||||
[&](rpc::Buffer *Buffer, uint32_t) { assert(Buffer->data[0] == 42); });
|
||||
Port.send([&](rpc::Buffer *, uint32_t) {});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return rpc::RPC_UNHANDLED_OPCODE;
|
||||
break;
|
||||
}
|
||||
return rpc::RPC_SUCCESS;
|
||||
}
|
||||
|
||||
static uint32_t handleOffloadOpcodes(void *Raw, uint32_t NumLanes) {
|
||||
rpc::Server::Port &Port = *reinterpret_cast<rpc::Server::Port *>(Raw);
|
||||
if (NumLanes == 1)
|
||||
return handleOpcodes<1>(Port);
|
||||
else if (NumLanes == 32)
|
||||
return handleOpcodes<32>(Port);
|
||||
else if (NumLanes == 64)
|
||||
return handleOpcodes<64>(Port);
|
||||
else
|
||||
return rpc::RPC_ERROR;
|
||||
}
|
||||
|
||||
[[gnu::weak]] rpc::Client client asm("__llvm_rpc_client");
|
||||
#pragma omp declare target to(client) device_type(nohost)
|
||||
|
||||
void __tgt_register_rpc_callback(unsigned (*Callback)(void *, unsigned));
|
||||
|
||||
int main() {
|
||||
__tgt_register_rpc_callback(&handleOffloadOpcodes);
|
||||
#pragma omp target
|
||||
{
|
||||
rpc::Client::Port Port = client.open<RPC_TEST_OPCODE>();
|
||||
Port.send([=](rpc::Buffer *buffer, uint32_t) { buffer->data[0] = 42; });
|
||||
Port.recv([](rpc::Buffer *, uint32_t) {});
|
||||
Port.close();
|
||||
}
|
||||
printf("PASS\n");
|
||||
}
|
||||
#endif
|
||||
205
offload/test/libc/rpc_callback.cpp
Normal file
205
offload/test/libc/rpc_callback.cpp
Normal file
@ -0,0 +1,205 @@
|
||||
// RUN: %libomptarget-compilexx-run-and-check-generic
|
||||
// REQUIRES: libc
|
||||
// REQUIRES: gpu
|
||||
|
||||
#include <assert.h>
|
||||
#include <omp.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
// CHECK: PASS
|
||||
|
||||
// If the RPC headers are not present, just pass the test.
|
||||
#if !__has_include(<../../libc/shared/rpc.h>)
|
||||
int main() { printf("PASS\n"); }
|
||||
#else
|
||||
|
||||
#include <../../libc/shared/rpc.h>
|
||||
#include <../../libc/shared/rpc_dispatch.h>
|
||||
|
||||
[[gnu::weak]] rpc::Client client asm("__llvm_rpc_client");
|
||||
#pragma omp declare target to(client) device_type(nohost)
|
||||
|
||||
//===------------------------------------------------------------------------===
|
||||
// Opcodes.
|
||||
//===------------------------------------------------------------------------===
|
||||
|
||||
constexpr uint32_t FOO_OPCODE = 1;
|
||||
constexpr uint32_t VOID_OPCODE = 2;
|
||||
constexpr uint32_t WRITEBACK_OPCODE = 3;
|
||||
constexpr uint32_t CONST_PTR_OPCODE = 4;
|
||||
constexpr uint32_t STRING_OPCODE = 5;
|
||||
constexpr uint32_t EMPTY_OPCODE = 6;
|
||||
constexpr uint32_t DIVERGENT_OPCODE = 7;
|
||||
|
||||
//===------------------------------------------------------------------------===
|
||||
// Server-side implementations.
|
||||
//===------------------------------------------------------------------------===
|
||||
|
||||
struct S {
|
||||
int arr[4];
|
||||
};
|
||||
|
||||
// 1. Non-pointer arguments, non-void return.
|
||||
int foo(int x, double d, char c) {
|
||||
assert(x == 42);
|
||||
assert(d == 0.0);
|
||||
assert(c == 'c');
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 2. Void return type.
|
||||
void void_fn(int x) { assert(x == 7); }
|
||||
|
||||
// 3. Write-back pointer.
|
||||
void writeback_fn(int *out) {
|
||||
assert(out != nullptr && *out == 42);
|
||||
*out = 99;
|
||||
}
|
||||
|
||||
// 4. Const pointer.
|
||||
int sum_const(const S *p) {
|
||||
int s = 0;
|
||||
for (int i = 0; i < 4; ++i)
|
||||
s += p->arr[i];
|
||||
return s;
|
||||
}
|
||||
|
||||
// 5. const char * string.
|
||||
int c_string(const char *s) {
|
||||
assert(s != nullptr);
|
||||
assert(strcmp(s, "hello") == 0);
|
||||
return strlen(s);
|
||||
}
|
||||
|
||||
// 6. Empty function.
|
||||
int empty() { return 42; }
|
||||
|
||||
// 7. Divergent values.
|
||||
void divergent(int *) {}
|
||||
|
||||
//===------------------------------------------------------------------------===
|
||||
// RPC client dispatch.
|
||||
//===------------------------------------------------------------------------===
|
||||
|
||||
#pragma omp begin declare variant match(device = {kind(gpu)})
|
||||
int foo(int x, double d, char c) {
|
||||
return rpc::dispatch<FOO_OPCODE>(client, foo, x, d, c);
|
||||
}
|
||||
|
||||
void void_fn(int x) { rpc::dispatch<VOID_OPCODE>(client, void_fn, x); }
|
||||
|
||||
void writeback_fn(int *out) {
|
||||
rpc::dispatch<WRITEBACK_OPCODE>(client, writeback_fn, out);
|
||||
}
|
||||
|
||||
int sum_const(const S *p) {
|
||||
return rpc::dispatch<CONST_PTR_OPCODE>(client, sum_const, p);
|
||||
}
|
||||
|
||||
int c_string(const char *s) {
|
||||
return rpc::dispatch<STRING_OPCODE>(client, c_string, s);
|
||||
}
|
||||
|
||||
int empty() { return rpc::dispatch<EMPTY_OPCODE>(client, empty); }
|
||||
|
||||
void divergent(int *p) {
|
||||
rpc::dispatch<DIVERGENT_OPCODE>(client, divergent, p);
|
||||
}
|
||||
#pragma omp end declare variant
|
||||
|
||||
//===------------------------------------------------------------------------===
|
||||
// RPC server dispatch.
|
||||
//===------------------------------------------------------------------------===
|
||||
|
||||
template <uint32_t NUM_LANES>
|
||||
rpc::Status handleOpcodesImpl(rpc::Server::Port &Port) {
|
||||
switch (Port.get_opcode()) {
|
||||
case FOO_OPCODE:
|
||||
rpc::invoke<NUM_LANES>(foo, Port);
|
||||
break;
|
||||
case VOID_OPCODE:
|
||||
rpc::invoke<NUM_LANES>(void_fn, Port);
|
||||
break;
|
||||
case WRITEBACK_OPCODE:
|
||||
rpc::invoke<NUM_LANES>(writeback_fn, Port);
|
||||
break;
|
||||
case CONST_PTR_OPCODE:
|
||||
rpc::invoke<NUM_LANES>(sum_const, Port);
|
||||
break;
|
||||
case STRING_OPCODE:
|
||||
rpc::invoke<NUM_LANES>(c_string, Port);
|
||||
break;
|
||||
case EMPTY_OPCODE:
|
||||
rpc::invoke<NUM_LANES>(empty, Port);
|
||||
break;
|
||||
case DIVERGENT_OPCODE:
|
||||
rpc::invoke<NUM_LANES>(divergent, Port);
|
||||
break;
|
||||
default:
|
||||
return rpc::RPC_UNHANDLED_OPCODE;
|
||||
}
|
||||
return rpc::RPC_SUCCESS;
|
||||
}
|
||||
|
||||
static uint32_t handleOpcodes(void *raw, uint32_t numLanes) {
|
||||
rpc::Server::Port &Port = *reinterpret_cast<rpc::Server::Port *>(raw);
|
||||
if (numLanes == 1)
|
||||
return handleOpcodesImpl<1>(Port);
|
||||
else if (numLanes == 32)
|
||||
return handleOpcodesImpl<32>(Port);
|
||||
else if (numLanes == 64)
|
||||
return handleOpcodesImpl<64>(Port);
|
||||
else
|
||||
return rpc::RPC_ERROR;
|
||||
}
|
||||
|
||||
extern "C" void __tgt_register_rpc_callback(unsigned (*callback)(void *,
|
||||
unsigned));
|
||||
|
||||
[[gnu::constructor]] void register_callback() {
|
||||
__tgt_register_rpc_callback(&handleOpcodes);
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
#pragma omp target
|
||||
#pragma omp parallel num_threads(32)
|
||||
{
|
||||
// 1. Non-pointer return.
|
||||
assert(foo(42, 0.0, 'c') == -1);
|
||||
|
||||
// 2. Void return.
|
||||
void_fn(7);
|
||||
|
||||
// 3. Write-back pointer.
|
||||
int value = 42;
|
||||
writeback_fn(&value);
|
||||
assert(value == 99);
|
||||
|
||||
// 4. Const pointer.
|
||||
S s{1, 2, 3, 4};
|
||||
int sum = sum_const(&s);
|
||||
assert(sum == 10);
|
||||
|
||||
// 5. const char * string.
|
||||
const char *msg = "hello";
|
||||
int len = c_string(msg);
|
||||
assert(len == 5);
|
||||
|
||||
// 6. No arguments.
|
||||
int ret = empty();
|
||||
assert(ret == 42);
|
||||
|
||||
// 7. Divergent values.
|
||||
int id = omp_get_thread_num();
|
||||
if (id % 2)
|
||||
divergent(&id);
|
||||
assert(id == omp_get_thread_num());
|
||||
}
|
||||
|
||||
printf("PASS\n");
|
||||
}
|
||||
|
||||
#endif
|
||||
Loading…
x
Reference in New Issue
Block a user