[OFFLOAD] Add support for indexed per-thread containers (#164263)
Split from #158900 it adds a PerThreadContainer that can use STL-like indexed containers based on a slightly refactored PerThreadTable. --------- Co-authored-by: Joseph Huber <huberjn@outlook.com>
This commit is contained in:
parent
bd04ef6df5
commit
3f22ed1152
@ -160,17 +160,11 @@ struct InteropTableEntry {
|
||||
Interops.push_back(obj);
|
||||
}
|
||||
|
||||
template <class ClearFuncTy> void clear(ClearFuncTy f) {
|
||||
for (auto &Obj : Interops) {
|
||||
f(Obj);
|
||||
}
|
||||
}
|
||||
|
||||
/// vector interface
|
||||
int size() const { return Interops.size(); }
|
||||
iterator begin() { return Interops.begin(); }
|
||||
iterator end() { return Interops.end(); }
|
||||
iterator erase(iterator it) { return Interops.erase(it); }
|
||||
void clear() { Interops.clear(); }
|
||||
};
|
||||
|
||||
struct InteropTblTy
|
||||
|
||||
@ -14,101 +14,256 @@
|
||||
#define OFFLOAD_PERTHREADTABLE_H
|
||||
|
||||
#include <list>
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <type_traits>
|
||||
|
||||
template <typename ObjectType> class PerThread {
|
||||
std::mutex Mutex;
|
||||
llvm::SmallVector<std::shared_ptr<ObjectType>> ThreadDataList;
|
||||
|
||||
ObjectType &getThreadData() {
|
||||
static thread_local std::shared_ptr<ObjectType> ThreadData = nullptr;
|
||||
if (!ThreadData) {
|
||||
ThreadData = std::make_shared<ObjectType>();
|
||||
std::lock_guard<std::mutex> Lock(Mutex);
|
||||
ThreadDataList.push_back(ThreadData);
|
||||
}
|
||||
return *ThreadData;
|
||||
}
|
||||
|
||||
public:
|
||||
// Define default constructors, disable copy and move constructors.
|
||||
PerThread() = default;
|
||||
PerThread(const PerThread &) = delete;
|
||||
PerThread(PerThread &&) = delete;
|
||||
PerThread &operator=(const PerThread &) = delete;
|
||||
PerThread &operator=(PerThread &&) = delete;
|
||||
~PerThread() {
|
||||
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
|
||||
"Cannot be deleted while other threads are adding entries");
|
||||
ThreadDataList.clear();
|
||||
}
|
||||
|
||||
ObjectType &get() { return getThreadData(); }
|
||||
|
||||
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
|
||||
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
|
||||
"Clear cannot be called while other threads are adding entries");
|
||||
for (std::shared_ptr<ObjectType> ThreadData : ThreadDataList) {
|
||||
if (!ThreadData)
|
||||
continue;
|
||||
ClearFunc(*ThreadData);
|
||||
}
|
||||
ThreadDataList.clear();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ContainerTy> struct ContainerConcepts {
|
||||
template <typename, template <typename> class, typename = std::void_t<>>
|
||||
struct has : std::false_type {};
|
||||
template <typename Ty, template <typename> class Op>
|
||||
struct has<Ty, Op, std::void_t<Op<Ty>>> : std::true_type {};
|
||||
|
||||
template <typename Ty> using IteratorTypeCheck = typename Ty::iterator;
|
||||
template <typename Ty> using MappedTypeCheck = typename Ty::mapped_type;
|
||||
template <typename Ty> using ValueTypeCheck = typename Ty::value_type;
|
||||
template <typename Ty> using KeyTypeCheck = typename Ty::key_type;
|
||||
template <typename Ty> using SizeTypeCheck = typename Ty::size_type;
|
||||
|
||||
template <typename Ty>
|
||||
using ClearCheck = decltype(std::declval<Ty>().clear());
|
||||
template <typename Ty>
|
||||
using ReserveCheck = decltype(std::declval<Ty>().reserve(1));
|
||||
template <typename Ty>
|
||||
using ResizeCheck = decltype(std::declval<Ty>().resize(1));
|
||||
|
||||
static constexpr bool hasIterator =
|
||||
has<ContainerTy, IteratorTypeCheck>::value;
|
||||
static constexpr bool hasClear = has<ContainerTy, ClearCheck>::value;
|
||||
static constexpr bool isAssociative =
|
||||
has<ContainerTy, MappedTypeCheck>::value;
|
||||
static constexpr bool hasReserve = has<ContainerTy, ReserveCheck>::value;
|
||||
static constexpr bool hasResize = has<ContainerTy, ResizeCheck>::value;
|
||||
|
||||
template <typename, template <typename> class, typename = std::void_t<>>
|
||||
struct has_type {
|
||||
using type = void;
|
||||
};
|
||||
template <typename Ty, template <typename> class Op>
|
||||
struct has_type<Ty, Op, std::void_t<Op<Ty>>> {
|
||||
using type = Op<Ty>;
|
||||
};
|
||||
|
||||
using iterator = typename has_type<ContainerTy, IteratorTypeCheck>::type;
|
||||
using value_type = typename std::conditional_t<
|
||||
isAssociative, typename has_type<ContainerTy, MappedTypeCheck>::type,
|
||||
typename has_type<ContainerTy, ValueTypeCheck>::type>;
|
||||
using key_type = typename std::conditional_t<
|
||||
isAssociative, typename has_type<ContainerTy, KeyTypeCheck>::type,
|
||||
typename has_type<ContainerTy, SizeTypeCheck>::type>;
|
||||
};
|
||||
|
||||
// Using an STL container (such as std::vector) indexed by thread ID has
|
||||
// too many race conditions issues so we store each thread entry into a
|
||||
// thread_local variable.
|
||||
// T is the container type used to store the objects, e.g., std::vector,
|
||||
// std::set, etc. by each thread. O is the type of the stored objects e.g.,
|
||||
// omp_interop_val_t *, ...
|
||||
|
||||
template <typename ContainerType, typename ObjectType> struct PerThreadTable {
|
||||
using iterator = typename ContainerType::iterator;
|
||||
// ContainerType is the container type used to store the objects, e.g.,
|
||||
// std::vector, std::set, etc. by each thread. ObjectType is the type of the
|
||||
// stored objects e.g., omp_interop_val_t *, ...
|
||||
template <typename ContainerType, typename ObjectType> class PerThreadTable {
|
||||
using iterator = typename ContainerConcepts<ContainerType>::iterator;
|
||||
|
||||
struct PerThreadData {
|
||||
size_t NElements = 0;
|
||||
std::unique_ptr<ContainerType> ThEntry;
|
||||
size_t Size = 0;
|
||||
std::unique_ptr<ContainerType> ThreadEntry;
|
||||
};
|
||||
|
||||
std::mutex Mtx;
|
||||
std::list<std::shared_ptr<PerThreadData>> ThreadDataList;
|
||||
std::mutex Mutex;
|
||||
llvm::SmallVector<std::shared_ptr<PerThreadData>> ThreadDataList;
|
||||
|
||||
// define default constructors, disable copy and move constructors
|
||||
PerThreadData &getThreadData() {
|
||||
static thread_local std::shared_ptr<PerThreadData> ThreadData = nullptr;
|
||||
if (!ThreadData) {
|
||||
ThreadData = std::make_shared<PerThreadData>();
|
||||
std::lock_guard<std::mutex> Lock(Mutex);
|
||||
ThreadDataList.push_back(ThreadData);
|
||||
}
|
||||
return *ThreadData;
|
||||
}
|
||||
|
||||
protected:
|
||||
ContainerType &getThreadEntry() {
|
||||
PerThreadData &ThreadData = getThreadData();
|
||||
if (ThreadData.ThreadEntry)
|
||||
return *ThreadData.ThreadEntry;
|
||||
ThreadData.ThreadEntry = std::make_unique<ContainerType>();
|
||||
return *ThreadData.ThreadEntry;
|
||||
}
|
||||
|
||||
size_t &getThreadSize() {
|
||||
PerThreadData &ThreadData = getThreadData();
|
||||
return ThreadData.Size;
|
||||
}
|
||||
|
||||
void setSize(size_t Size) {
|
||||
size_t &SizeRef = getThreadSize();
|
||||
SizeRef = Size;
|
||||
}
|
||||
|
||||
public:
|
||||
// define default constructors, disable copy and move constructors.
|
||||
PerThreadTable() = default;
|
||||
PerThreadTable(const PerThreadTable &) = delete;
|
||||
PerThreadTable(PerThreadTable &&) = delete;
|
||||
PerThreadTable &operator=(const PerThreadTable &) = delete;
|
||||
PerThreadTable &operator=(PerThreadTable &&) = delete;
|
||||
~PerThreadTable() {
|
||||
std::lock_guard<std::mutex> Lock(Mtx);
|
||||
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
|
||||
"Cannot be deleted while other threads are adding entries");
|
||||
ThreadDataList.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
PerThreadData &getThreadData() {
|
||||
static thread_local std::shared_ptr<PerThreadData> ThData = nullptr;
|
||||
if (!ThData) {
|
||||
ThData = std::make_shared<PerThreadData>();
|
||||
std::lock_guard<std::mutex> Lock(Mtx);
|
||||
ThreadDataList.push_back(ThData);
|
||||
}
|
||||
return *ThData;
|
||||
}
|
||||
|
||||
protected:
|
||||
ContainerType &getThreadEntry() {
|
||||
auto &ThData = getThreadData();
|
||||
if (ThData.ThEntry)
|
||||
return *ThData.ThEntry;
|
||||
ThData.ThEntry = std::make_unique<ContainerType>();
|
||||
return *ThData.ThEntry;
|
||||
}
|
||||
|
||||
size_t &getThreadNElements() {
|
||||
auto &ThData = getThreadData();
|
||||
return ThData.NElements;
|
||||
}
|
||||
|
||||
public:
|
||||
void add(ObjectType obj) {
|
||||
auto &Entry = getThreadEntry();
|
||||
auto &NElements = getThreadNElements();
|
||||
NElements++;
|
||||
ContainerType &Entry = getThreadEntry();
|
||||
size_t &SizeRef = getThreadSize();
|
||||
SizeRef++;
|
||||
Entry.add(obj);
|
||||
}
|
||||
|
||||
iterator erase(iterator it) {
|
||||
auto &Entry = getThreadEntry();
|
||||
auto &NElements = getThreadNElements();
|
||||
NElements--;
|
||||
ContainerType &Entry = getThreadEntry();
|
||||
size_t &SizeRef = getThreadSize();
|
||||
SizeRef--;
|
||||
return Entry.erase(it);
|
||||
}
|
||||
|
||||
size_t size() { return getThreadNElements(); }
|
||||
size_t size() { return getThreadSize(); }
|
||||
|
||||
// Iterators to traverse objects owned by
|
||||
// the current thread
|
||||
// the current thread.
|
||||
iterator begin() {
|
||||
auto &Entry = getThreadEntry();
|
||||
ContainerType &Entry = getThreadEntry();
|
||||
return Entry.begin();
|
||||
}
|
||||
iterator end() {
|
||||
auto &Entry = getThreadEntry();
|
||||
ContainerType &Entry = getThreadEntry();
|
||||
return Entry.end();
|
||||
}
|
||||
|
||||
template <class F> void clear(F f) {
|
||||
std::lock_guard<std::mutex> Lock(Mtx);
|
||||
for (auto ThData : ThreadDataList) {
|
||||
if (!ThData->ThEntry || ThData->NElements == 0)
|
||||
template <class ClearFuncTy> void clear(ClearFuncTy ClearFunc) {
|
||||
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
|
||||
"Clear cannot be called while other threads are adding entries");
|
||||
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
|
||||
if (!ThreadData->ThreadEntry || ThreadData->Size == 0)
|
||||
continue;
|
||||
ThData->ThEntry->clear(f);
|
||||
ThData->NElements = 0;
|
||||
if constexpr (ContainerConcepts<ContainerType>::hasIterator &&
|
||||
ContainerConcepts<ContainerType>::hasClear) {
|
||||
for (auto &Obj : *ThreadData->ThreadEntry) {
|
||||
if constexpr (ContainerConcepts<ContainerType>::isAssociative) {
|
||||
ClearFunc(Obj.second);
|
||||
} else {
|
||||
ClearFunc(Obj);
|
||||
}
|
||||
}
|
||||
ThreadData->ThreadEntry->clear();
|
||||
} else {
|
||||
static_assert(true, "Container type not supported");
|
||||
}
|
||||
ThreadData->Size = 0;
|
||||
}
|
||||
ThreadDataList.clear();
|
||||
}
|
||||
|
||||
template <class DeinitFuncTy> llvm::Error deinit(DeinitFuncTy DeinitFunc) {
|
||||
assert(Mutex.try_lock() && (Mutex.unlock(), true) &&
|
||||
"Deinit cannot be called while other threads are adding entries");
|
||||
for (std::shared_ptr<PerThreadData> ThreadData : ThreadDataList) {
|
||||
if (!ThreadData->ThreadEntry || ThreadData->Size == 0)
|
||||
continue;
|
||||
for (auto &Obj : *ThreadData->ThreadEntry) {
|
||||
if constexpr (ContainerConcepts<ContainerType>::isAssociative) {
|
||||
if (auto Err = DeinitFunc(Obj.second))
|
||||
return Err;
|
||||
} else {
|
||||
if (auto Err = DeinitFunc(Obj))
|
||||
return Err;
|
||||
}
|
||||
}
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ContainerType, size_t ReserveSize = 0>
|
||||
class PerThreadContainer
|
||||
: public PerThreadTable<ContainerType, typename ContainerConcepts<
|
||||
ContainerType>::value_type> {
|
||||
|
||||
using IndexType = typename ContainerConcepts<ContainerType>::key_type;
|
||||
using ObjectType = typename ContainerConcepts<ContainerType>::value_type;
|
||||
|
||||
public:
|
||||
// Get the object for the given index in the current thread.
|
||||
ObjectType &get(IndexType Index) {
|
||||
ContainerType &Entry = this->getThreadEntry();
|
||||
|
||||
// Specialized code for vector-like containers.
|
||||
if constexpr (ContainerConcepts<ContainerType>::hasResize) {
|
||||
if (Index >= Entry.size()) {
|
||||
if constexpr (ContainerConcepts<ContainerType>::hasReserve &&
|
||||
ReserveSize > 0)
|
||||
Entry.reserve(ReserveSize);
|
||||
|
||||
// If the index is out of bounds, try resize the container.
|
||||
Entry.resize(Index + 1);
|
||||
}
|
||||
}
|
||||
ObjectType &Ret = Entry[Index];
|
||||
this->setSize(Entry.size());
|
||||
return Ret;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user