Eugene Zhulenev d4f1a3c6e2 [mlir] Add microbenchmark for linalg+async-parallel-for
Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D91896
2020-11-21 03:47:14 -08:00

289 lines
8.8 KiB
C++

//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/AsyncRuntime.h"
#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>
#include "llvm/Support/ThreadPool.h"
//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//
namespace {
// Forward declare class defined below.
class RefCounted;
// -------------------------------------------------------------------------- //
// AsyncRuntime coordinates all async operations. The Async runtime API in this
// file is implemented on top of the default runtime instance.
// -------------------------------------------------------------------------- //
class AsyncRuntime {
public:
  AsyncRuntime() = default;

  // In assert-enabled builds, verifies that every reference counted object
  // registered with this runtime has already been destroyed.
  ~AsyncRuntime() {
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  // Returns the number of reference counted objects currently registered with
  // this runtime (relaxed load of the counter).
  int32_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

private:
  // RefCounted is the only user of the add/drop bookkeeping below.
  friend class RefCounted;

  // Bookkeeping for the total number of reference counted objects owned by
  // this AsyncRuntime instance. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int32_t> numRefCountedObjects{0};
};
// Returns the default per-process instance of an async runtime.
//
// The instance is a function-local static: it is constructed on the first
// call and destroyed during static destruction at program exit, at which point
// the AsyncRuntime destructor runs its leak check.
AsyncRuntime *getDefaultAsyncRuntimeInstance() {
  // A static object needs no heap allocation and no <memory> facilities,
  // which this file does not include.
  static AsyncRuntime runtime;
  return &runtime;
}
// -------------------------------------------------------------------------- //
// Common base class of every reference counted object the async runtime
// creates. Objects register with their runtime on construction and unregister
// on destruction.
// -------------------------------------------------------------------------- //
class RefCounted {
public:
  // Reference counted objects are neither copyable nor assignable.
  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  // Creates an object owned by `runtime` that starts with `refCount`
  // references.
  RefCounted(AsyncRuntime *runtime, int32_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    this->runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  // Adds `count` references.
  void addRef(int32_t count = 1) { refCount.fetch_add(count); }

  // Drops `count` references and destroys the object when none remain.
  void dropRef(int32_t count = 1) {
    const int32_t before = refCount.fetch_sub(count);
    assert(before >= count && "reference count should not go below zero");
    // Other references are still outstanding.
    if (before != count)
      return;
    destroy();
  }

protected:
  // Called once the last reference is dropped; the default deletes the object.
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int32_t> refCount;
};
} // namespace
struct AsyncToken : public RefCounted {
  // A token starts with two references: one is returned to the
  // `async.execute` caller, the other belongs to the asynchronously executed
  // task that emplaces the token later on. The second reference keeps the
  // token alive until the asynchronous operation completes, even if the
  // caller drops its own reference immediately.
  AsyncToken(AsyncRuntime *runtime) : RefCounted(runtime, /*refCount=*/2) {}

  // Guards the state below; `cv` is waited on with this mutex held.
  std::mutex mu;
  std::condition_variable cv;

  // Set to true when the token is emplaced.
  bool ready{false};
  // Callbacks to run when the token is emplaced.
  std::vector<std::function<void()>> awaiters;
};
struct AsyncGroup : public RefCounted {
  // Creates an empty group with a single reference.
  AsyncGroup(AsyncRuntime *runtime) : RefCounted(runtime) {}

  // Number of tokens added to the group that are not ready yet.
  std::atomic<int> pendingTokens{0};
  // Counter used to hand out the rank of the next token added to the group.
  std::atomic<int> rank{0};

  // Guards the state below; `cv` is waited on with this mutex held.
  std::mutex mu;
  std::condition_variable cv;
  // Callbacks to run when the pending token count drops to zero.
  std::vector<std::function<void()>> awaiters;
};
// Adds `count` references to the reference counted runtime object `ptr`.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int32_t count) {
  static_cast<RefCounted *>(ptr)->addRef(count);
}
// Drops `count` references from the reference counted runtime object `ptr`.
// The object is destroyed when its last reference is dropped.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int32_t count) {
  static_cast<RefCounted *>(ptr)->dropRef(count);
}
// Creates a new `async.token` in the not-ready state, registered with the
// default runtime instance.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  return new AsyncToken(getDefaultAsyncRuntimeInstance());
}
// Creates a new, empty `async.group` registered with the default runtime
// instance.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
  return new AsyncGroup(getDefaultAsyncRuntimeInstance());
}
// Adds `token` to `group` and returns the rank of the token inside the group
// (the number of tokens that were added before it).
//
// The group's pending-token count is incremented here and decremented once
// the token is ready; when it reaches zero the group waiters are notified and
// the group awaiters are run.
extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);
  group->pendingTokens.fetch_add(1);

  // Must be called with `group->mu` held: the pending-token count is the
  // predicate of `group->cv` and must change under the mutex so that a
  // concurrent waiter cannot miss the notification, and `group->awaiters` is
  // concurrently appended to under the same mutex.
  auto onTokenReady = [group]() {
    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (token->ready) {
    // The token is already ready: update the group right away. Both locks
    // are held here.
    onTokenReady();
  } else {
    // The update happens later, when the token is emplaced. Keep the group
    // and the token alive until then.
    group->addRef();
    token->addRef();
    token->awaiters.push_back([group, token, onTokenReady]() {
      // Re-acquire the group lock for the update, and release it before
      // dropping references so that `dropRef` can never destroy a mutex that
      // is still locked.
      {
        std::unique_lock<std::mutex> groupLock(group->mu);
        onTokenReady();
      }
      // We no longer need the token or the group, drop references on them.
      group->dropRef();
      token->dropRef();
    });
  }

  return rank;
}
// Switches `async.token` to ready state and runs all awaiters, then drops the
// reference that was reserved for the asynchronous task at token creation.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  // Keep the lock in its own scope: it must be released before `dropRef`
  // below, which may destroy the token together with the mutex the lock
  // refers to.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->ready = true;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens are created with a ref count of `2` to keep the token alive
  // until the async task completes. Drop this reference explicitly now that
  // the token is emplaced.
  token->dropRef();
}
// Blocks the calling thread until `token` is in the ready state.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> guard(token->mu);
  // Loop to tolerate spurious wakeups of the condition variable.
  while (!token->ready)
    token->cv.wait(guard);
}
// Blocks the calling thread until `group` has no pending tokens.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> guard(group->mu);
  // Loop to tolerate spurious wakeups of the condition variable.
  while (group->pendingTokens.load() != 0)
    group->cv.wait(guard);
}
// Resumes the coroutine `handle` by calling `resume` on it. When LLVM is
// built with threads enabled the call is submitted to a process-wide thread
// pool; otherwise it runs inline on the calling thread.
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto resumeCoroutine = [handle, resume]() { (*resume)(handle); };
#if LLVM_ENABLE_THREADS
  // The pool is created on first use and is never destroyed.
  static llvm::ThreadPool *pool = new llvm::ThreadPool();
  pool->async(resumeCoroutine);
#else
  resumeCoroutine();
#endif
}
// Resumes the coroutine `handle` (via `mlirAsyncRuntimeExecute`) once `token`
// is ready: immediately if it already is, otherwise from an awaiter that runs
// when the token is emplaced.
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  std::unique_lock<std::mutex> lock(token->mu);

  if (token->ready) {
    // Release the token lock before resuming. Without thread support
    // `mlirAsyncRuntimeExecute` runs the coroutine inline, and the resumed
    // code must be free to use (or release) this token without running into
    // a mutex that is still held here.
    lock.unlock();
    mlirAsyncRuntimeExecute(handle, resume);
    return;
  }

  // Keep the token alive until the awaiter has run; the awaiter gives the
  // reference back before scheduling the coroutine.
  token->addRef();
  token->awaiters.push_back([token, handle, resume]() {
    token->dropRef();
    mlirAsyncRuntimeExecute(handle, resume);
  });
}
// Resumes the coroutine `handle` (via `mlirAsyncRuntimeExecute`) once `group`
// has no pending tokens: immediately if that is already the case, otherwise
// from an awaiter that runs when the pending-token count drops to zero.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
                                          CoroResume resume) {
  std::unique_lock<std::mutex> lock(group->mu);

  if (group->pendingTokens == 0) {
    // Release the group lock before resuming. Without thread support
    // `mlirAsyncRuntimeExecute` runs the coroutine inline, and the resumed
    // code must be free to use (or release) this group without running into
    // a mutex that is still held here.
    lock.unlock();
    mlirAsyncRuntimeExecute(handle, resume);
    return;
  }

  // Keep the group alive until the awaiter has run; the awaiter gives the
  // reference back before scheduling the coroutine.
  group->addRef();
  group->awaiters.push_back([group, handle, resume]() {
    group->dropRef();
    mlirAsyncRuntimeExecute(handle, resume);
  });
}
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
// Prints the id of the calling thread to standard output.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // Looked up once per thread and reused on later calls from that thread.
  thread_local const std::thread::id cachedId = std::this_thread::get_id();
  std::cout << "Current thread id: " << cachedId << "\n";
}
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS