
Depends On D89958 1. Adds `async.group`/`async.awaitall` to group together multiple async tokens/values. 2. Rewrites the scf.parallel operation into multiple concurrent async.execute operations over non-overlapping subranges of the original loop. Example: ``` scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) { "do_some_compute"(%i, %j): () -> () } ``` Converted to: ``` %c0 = constant 0 : index %c1 = constant 1 : index // Compute block sizes for each induction variable. %num_blocks_i = ... : index %num_blocks_j = ... : index %block_size_i = ... : index %block_size_j = ... : index // Create an async group to track async execute ops. %group = async.create_group scf.for %bi = %c0 to %num_blocks_i step %c1 { %block_start_i = ... : index %block_end_i = ... : index scf.for %bj = %c0 to %num_blocks_j step %c1 { %block_start_j = ... : index %block_end_j = ... : index // Execute the body of the original parallel operation for the current // block. %token = async.execute { scf.for %i = %block_start_i to %block_end_i step %si { scf.for %j = %block_start_j to %block_end_j step %sj { "do_some_compute"(%i, %j): () -> () } } } // Add the produced async token to the group. async.add_to_group %token, %group } } // Await completion of all async.execute operations. async.await_all %group ``` In this example, the outer loops launch the inner block-level loops as separate async.execute operations, which are executed concurrently. At the end, it waits for the completion of all async.execute operations. Reviewed By: ftynse, mehdi_amini Differential Revision: https://reviews.llvm.org/D89963
152 lines
4.5 KiB
C++
152 lines
4.5 KiB
C++
//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements basic Async runtime API for supporting Async dialect
|
|
// to LLVM dialect lowering.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/ExecutionEngine/AsyncRuntime.h"
|
|
|
|
#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|
|
|
|
#include <atomic>
|
|
#include <condition_variable>
|
|
#include <functional>
|
|
#include <iostream>
|
|
#include <mutex>
|
|
#include <thread>
|
|
#include <vector>
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Async runtime API.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Reference-runtime state behind an `!async.token` handle.
struct AsyncToken {
  // Becomes true once the token is emplaced; read and written under `mu`.
  bool ready{false};
  // Protects `ready` and `awaiters`.
  std::mutex mu;
  // Signalled when the token becomes ready, to wake blocking waiters.
  std::condition_variable cv;
  // Callbacks to invoke once the token becomes ready.
  std::vector<std::function<void()>> awaiters;
};
|
|
|
|
// Reference-runtime state behind an `!async.group` handle.
struct AsyncGroup {
  // Number of tokens added to the group that have not become ready yet.
  std::atomic<int> pendingTokens{0};
  // Next rank to hand out; each added token receives the previous value.
  std::atomic<int> rank{0};
  // Mutex paired with `cv`; also held when registering `awaiters`.
  std::mutex mu;
  // Signalled when `pendingTokens` drops to zero.
  std::condition_variable cv;
  // Callbacks to invoke once every token in the group is ready.
  std::vector<std::function<void()>> awaiters;
};
|
|
|
|
// Create a new `async.token` in not-ready state.
|
|
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
|
|
AsyncToken *token = new AsyncToken;
|
|
return token;
|
|
}
|
|
|
|
// Create a new `async.group` in empty state.
|
|
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
|
|
AsyncGroup *group = new AsyncGroup;
|
|
return group;
|
|
}
|
|
|
|
// Adds `token` to `group` and returns the token's rank. The rank is its
// 0-based position in the order tokens were added to the group.
//
// The group's pending-token counter is incremented now and decremented once
// the token becomes ready. The decrement that reaches zero does two things:
// it wakes threads blocked in `mlirAsyncRuntimeAwaitAllInGroup`, and it runs
// the callbacks registered in `group->awaiters`.
extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
  // Locks are taken in the order `token->mu`, then `group->mu`. The deferred
  // callback below follows the same order.
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  group->pendingTokens.fetch_add(1);

  // Must be called with `group->mu` held. Waiters check `pendingTokens` and
  // register awaiters under that mutex. If the decrement, the notify and the
  // walk over `awaiters` happened outside it, two things could go wrong:
  //   - a waiter could miss the wakeup and block forever;
  //   - an awaiter could be appended after the list was already drained.
  auto onTokenReady = [group]() {
    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (token->ready) {
    // `group->mu` is already held via `lockGroup`.
    onTokenReady();
  } else {
    // This runs later from `mlirAsyncRuntimeEmplaceToken`, which holds only
    // `token->mu`, so re-acquire `group->mu` before updating the group.
    token->awaiters.push_back([group, onTokenReady]() {
      std::unique_lock<std::mutex> lock(group->mu);
      onTokenReady();
    });
  }

  return group->rank.fetch_add(1);
}
|
|
|
|
// Switches `async.token` to ready state and runs all awaiters.
|
|
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
|
|
std::unique_lock<std::mutex> lock(token->mu);
|
|
token->ready = true;
|
|
token->cv.notify_all();
|
|
for (auto &awaiter : token->awaiters)
|
|
awaiter();
|
|
}
|
|
|
|
// Blocks the calling thread until `token` has been emplaced. The predicate
// overload of `wait` evaluates `ready` before sleeping, so an already-ready
// token returns immediately.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> guard(token->mu);
  token->cv.wait(guard, [token] { return token->ready; });
}
|
|
|
|
// Blocks the calling thread until every token added to `group` is ready,
// i.e. until `pendingTokens` drops to zero. Returns immediately if it already
// is zero.
// NOTE(review): the wakeup is only reliable if whoever decrements
// `pendingTokens` and notifies `cv` holds `group->mu` while doing so.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> guard(group->mu);
  group->cv.wait(guard, [group] { return group->pendingTokens == 0; });
}
|
|
|
|
// Resumes the coroutine `handle` by calling `resume`. With LLVM_ENABLE_THREADS
// this happens on a newly spawned, detached thread. Otherwise it runs inline
// on the caller's thread.
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto resumeCoro = [handle, resume]() { (*resume)(handle); };
#if LLVM_ENABLE_THREADS
  std::thread(resumeCoro).detach();
#else
  resumeCoro();
#endif
}
|
|
|
|
// Schedules the coroutine `handle` to resume once `token` is ready. If the
// token is already ready, it is handed to `mlirAsyncRuntimeExecute` right away.
// Otherwise that call is queued on the token's awaiter list.
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  std::lock_guard<std::mutex> guard(token->mu);

  if (!token->ready) {
    token->awaiters.push_back(
        [handle, resume]() { mlirAsyncRuntimeExecute(handle, resume); });
    return;
  }

  mlirAsyncRuntimeExecute(handle, resume);
}
|
|
|
|
// Schedules the coroutine `handle` to resume once every token in `group` is
// ready. If no tokens are pending, it is handed to `mlirAsyncRuntimeExecute`
// right away. Otherwise that call is queued on the group's awaiter list.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
                                          CoroResume resume) {
  std::lock_guard<std::mutex> guard(group->mu);

  if (group->pendingTokens != 0) {
    group->awaiters.push_back(
        [handle, resume]() { mlirAsyncRuntimeExecute(handle, resume); });
    return;
  }

  mlirAsyncRuntimeExecute(handle, resume);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Small async runtime support library for testing.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Test helper: writes the calling thread's id to stdout.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // Queried once per thread and reused on later calls from the same thread.
  static thread_local const std::thread::id callerId =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << callerId << "\n";
}
|
|
|
|
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
|