Eugene Zhulenev c30ab6c2a3 [mlir] Transform scf.parallel to scf.for + async.execute
Depends On D89958

1. Adds `async.create_group`/`async.add_to_group`/`async.await_all` to group together multiple async tokens/values and await all of them at once.
2. Rewrites the `scf.parallel` operation into multiple concurrent `async.execute` operations over non-overlapping subranges of the original loop.

Example:

```
   scf.parallel (%i, %j) = (%lbi, %lbj) to (%ubi, %ubj) step (%si, %sj) {
     "do_some_compute"(%i, %j): () -> ()
   }
```

Converted to:

```
   %c0 = constant 0 : index
   %c1 = constant 1 : index

   // Compute blocks sizes for each induction variable.
   %num_blocks_i = ... : index
   %num_blocks_j = ... : index
   %block_size_i = ... : index
   %block_size_j = ... : index

   // Create an async group to track async execute ops.
   %group = async.create_group

   scf.for %bi = %c0 to %num_blocks_i step %c1 {
     %block_start_i = ... : index
     %block_end_i   = ... : index

     scf.for %bj = %c0 to %num_blocks_j step %c1 {
       %block_start_j = ... : index
       %block_end_j   = ... : index

       // Execute the body of original parallel operation for the current
       // block.
       %token = async.execute {
         scf.for %i = %block_start_i to %block_end_i step %si {
           scf.for %j = %block_start_j to %block_end_j step %sj {
             "do_some_compute"(%i, %j): () -> ()
           }
         }
       }

       // Add produced async token to the group.
       async.add_to_group %token, %group
     }
   }

   // Await completion of all async.execute operations.
   async.await_all %group
```
In this example, the outer loops launch the inner block-level loops as separate
`async.execute` operations, which will be executed concurrently.

At the end, it waits for the completion of all async execute operations.

Reviewed By: ftynse, mehdi_amini

Differential Revision: https://reviews.llvm.org/D89963
2020-11-13 04:02:56 -08:00

152 lines
4.5 KiB
C++

//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements basic Async runtime API for supporting Async dialect
// to LLVM dialect lowering.
//
//===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/AsyncRuntime.h"
#ifdef MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS
#include <atomic>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>
//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//
// Runtime state behind an `async.token`: a one-shot readiness flag, a
// condition variable for threads that block on it, and a list of callbacks to
// invoke once it becomes ready.
struct AsyncToken {
  // Set to true by mlirAsyncRuntimeEmplaceToken.
  bool ready{false};
  // Guards `ready` and `awaiters`.
  std::mutex mu;
  // Notified when `ready` flips to true; used by blocking waiters.
  std::condition_variable cv;
  // Callbacks run when the token becomes ready.
  std::vector<std::function<void()>> awaiters;
};
// Runtime state behind an `async.group`: counts tokens that are still pending
// and hands out a rank to each token as it is added.
struct AsyncGroup {
  // Number of added tokens that have not become ready yet.
  std::atomic<int> pendingTokens{0};
  // Next rank returned by mlirAsyncRuntimeAddTokenToGroup.
  std::atomic<int> rank{0};
  // Mutex paired with `cv` for blocking waits on the group.
  std::mutex mu;
  // Notified when `pendingTokens` drops to zero.
  std::condition_variable cv;
  // Callbacks run when the last pending token becomes ready.
  std::vector<std::function<void()>> awaiters;
};
// Heap-allocates a fresh `async.token` (not ready, no awaiters) and hands the
// pointer to the caller.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() { return new AsyncToken; }
// Heap-allocates a fresh `async.group` (no pending tokens, rank counter at
// zero, no awaiters) and hands the pointer to the caller.
extern "C" MLIR_ASYNCRUNTIME_EXPORT AsyncGroup *mlirAsyncRuntimeCreateGroup() {
  return new AsyncGroup;
}
// Adds `token` to `group` and returns the token's rank inside the group (the
// value of the group's rank counter before this call).
//
// The group's pending-token counter is incremented now and decremented when
// the token becomes ready: immediately if it already is, otherwise from the
// token's awaiter list when mlirAsyncRuntimeEmplaceToken runs. When the
// counter drops to zero, blocked waiters are notified and every group awaiter
// is invoked.
//
// Lock order: `token->mu` is always taken before `group->mu`.
extern "C" MLIR_ASYNCRUNTIME_EXPORT int64_t
mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  int64_t rank = group->rank.fetch_add(1);
  group->pendingTokens.fetch_add(1);

  // Must run with `group->mu` held. mlirAsyncRuntimeAwaitAllInGroup and
  // mlirAsyncRuntimeAwaitAllInGroupAndExecute inspect `pendingTokens` and
  // touch `awaiters` under that mutex. Decrementing, notifying, or iterating
  // `awaiters` without it loses wakeups and races with `awaiters.push_back`.
  auto onTokenReady = [group]() {
    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (token->ready) {
    // `group->mu` is already held by `lockGroup`.
    onTokenReady();
  } else {
    // This callback runs later from mlirAsyncRuntimeEmplaceToken, which holds
    // only `token->mu`. Re-acquire `group->mu` here; this keeps the same
    // token-then-group lock order used above, so it cannot deadlock.
    token->awaiters.push_back([group, onTokenReady]() {
      std::unique_lock<std::mutex> lock(group->mu);
      onTokenReady();
    });
  }

  return rank;
}
// Marks `token` as ready, wakes every thread blocked on it, then invokes the
// queued awaiter callbacks in the order they were registered. Everything
// happens while `token->mu` is held.
extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  std::lock_guard<std::mutex> guard(token->mu);
  token->ready = true;
  token->cv.notify_all();
  for (std::function<void()> &onReady : token->awaiters)
    onReady();
}
// Blocks the calling thread until `token` is ready; returns immediately if it
// already is.
extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> guard(token->mu);
  // The predicate overload checks `ready` before every wait, so it needs no
  // separate fast-path test and also tolerates spurious wakeups.
  token->cv.wait(guard, [token] { return token->ready; });
}
// Blocks the calling thread until the group's pending-token counter reaches
// zero; returns immediately if nothing is pending.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> guard(group->mu);
  // The predicate is evaluated before the first wait, which covers the
  // "already done" case without an explicit check.
  group->cv.wait(guard, [group] { return group->pendingTokens == 0; });
}
// Resumes coroutine `handle` by calling `resume(handle)`. If LLVM_ENABLE_THREADS
// is set, the call happens on a newly spawned, detached thread. Otherwise it
// runs synchronously on the caller's thread.
extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
#if LLVM_ENABLE_THREADS
  // NOTE(review): one OS thread per task, never joined. It is simple, but a
  // thread pool would be cheaper for many short tasks.
  std::thread([handle, resume] { (*resume)(handle); }).detach();
#else
  (*resume)(handle);
#endif
}
// Schedules coroutine `handle` (through mlirAsyncRuntimeExecute) to resume once
// `token` is ready. If the token is already ready, it is scheduled right away.
// Otherwise the request is queued on the token's awaiter list and fires from
// mlirAsyncRuntimeEmplaceToken.
extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() {
    mlirAsyncRuntimeExecute(handle, resume);
  };

  std::unique_lock<std::mutex> lock(token->mu);
  if (!token->ready) {
    token->awaiters.push_back(execute);
    return;
  }

  // The token is ready and no token state is touched below, so release the
  // lock before scheduling. Without LLVM_ENABLE_THREADS,
  // mlirAsyncRuntimeExecute resumes the coroutine inline. Holding `token->mu`
  // then would self-deadlock if the resumed code re-entered the runtime on the
  // same token, because std::mutex is not recursive.
  lock.unlock();
  execute();
}
// Schedules coroutine `handle` (through mlirAsyncRuntimeExecute) to resume once
// every token added to `group` is ready. If no tokens are pending, it is
// scheduled right away. Otherwise the request is queued on the group's
// awaiter list, which runs when the pending-token counter drops to zero.
extern "C" MLIR_ASYNCRUNTIME_EXPORT void
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group, CoroHandle handle,
                                          CoroResume resume) {
  auto execute = [handle, resume]() {
    mlirAsyncRuntimeExecute(handle, resume);
  };

  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0) {
    group->awaiters.push_back(execute);
    return;
  }

  // Nothing is pending and no group state is touched below, so release the
  // lock before scheduling. Without LLVM_ENABLE_THREADS the coroutine resumes
  // inline. Holding `group->mu` then would self-deadlock if the resumed code
  // re-entered the runtime on the same group, because std::mutex is not
  // recursive.
  lock.unlock();
  execute();
}
//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//
// Test helper: writes the calling thread's id to stdout.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // thread_local: every thread caches its own id on first call.
  static thread_local const std::thread::id currentId =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << currentId << "\n";
}
#endif // MLIR_ASYNCRUNTIME_DEFINE_FUNCTIONS