[MLIR][py] Add PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (#130109)
In some projects, such as JAX, ir.Context objects are used with multi-threading disabled to avoid
caching multiple thread pools:
623865fe95/jax/_src/interpreters/mlir.py (L606-L611)
However, when a context has multithreading enabled, it also uses locks on
the StorageUniquers, and this can help avoid data races in
multi-threaded execution (for example, with free-threaded CPython,
https://github.com/jax-ml/jax/issues/26272).
With this PR, a user can enable multi-threading, which 1) enables the additional
locking and 2) allows setting a shared thread pool so that cached contexts can
share one global pool.
This commit is contained in:
parent
78060a7df7
commit
ab18cc246c
@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
|
||||
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
|
||||
MlirLlvmThreadPool threadPool);
|
||||
|
||||
/// Returns the number of threads in the context's thread pool when
|
||||
/// multithreading is enabled. Returns 1 if multithreading is disabled.
|
||||
MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context);
|
||||
|
||||
/// Returns the thread pool used by the context. Multithreading must be enabled
|
||||
/// on the context; otherwise an assertion is raised.
|
||||
MLIR_CAPI_EXPORTED MlirLlvmThreadPool
|
||||
mlirContextGetThreadPool(MlirContext context);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2743,6 +2743,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
// __init__.py will subclass it with site-specific functionality and set a
|
||||
// "Context" attribute on this module.
|
||||
//----------------------------------------------------------------------------
|
||||
|
||||
// Expose DefaultThreadPool to python
|
||||
nb::class_<PyThreadPool>(m, "ThreadPool")
|
||||
.def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); })
|
||||
.def("get_max_concurrency", &PyThreadPool::getMaxConcurrency)
|
||||
.def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr);
|
||||
|
||||
nb::class_<PyMlirContext>(m, "_BaseContext")
|
||||
.def("__init__",
|
||||
[](PyMlirContext &self) {
|
||||
@ -2814,6 +2821,25 @@ void mlir::python::populateIRCore(nb::module_ &m) {
|
||||
mlirContextEnableMultithreading(self.get(), enable);
|
||||
},
|
||||
nb::arg("enable"))
|
||||
.def("set_thread_pool",
|
||||
[](PyMlirContext &self, PyThreadPool &pool) {
|
||||
// we should disable multi-threading first before setting
|
||||
// new thread pool otherwise the assert in
|
||||
// MLIRContext::setThreadPool will be raised.
|
||||
mlirContextEnableMultithreading(self.get(), false);
|
||||
mlirContextSetThreadPool(self.get(), pool.get());
|
||||
})
|
||||
.def("get_num_threads",
|
||||
[](PyMlirContext &self) {
|
||||
return mlirContextGetNumThreads(self.get());
|
||||
})
|
||||
.def("_mlir_thread_pool_ptr",
|
||||
[](PyMlirContext &self) {
|
||||
MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get());
|
||||
std::stringstream ss;
|
||||
ss << pool.ptr;
|
||||
return ss.str();
|
||||
})
|
||||
.def(
|
||||
"is_registered_operation",
|
||||
[](PyMlirContext &self, std::string &name) {
|
||||
|
@ -11,6 +11,7 @@
|
||||
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
|
||||
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
@ -22,9 +23,10 @@
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/IntegerSet.h"
|
||||
#include "mlir-c/Transforms.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
#include "mlir/Bindings/Python/Nanobind.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/ThreadPool.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
@ -158,6 +160,29 @@ private:
|
||||
FrameKind frameKind;
|
||||
};
|
||||
|
||||
/// Wrapper around MlirLlvmThreadPool
|
||||
/// Python object owns the C++ thread pool
|
||||
class PyThreadPool {
|
||||
public:
|
||||
PyThreadPool() {
|
||||
ownedThreadPool = std::make_unique<llvm::DefaultThreadPool>();
|
||||
}
|
||||
PyThreadPool(const PyThreadPool &) = delete;
|
||||
PyThreadPool(PyThreadPool &&) = delete;
|
||||
|
||||
int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); }
|
||||
MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); }
|
||||
|
||||
std::string _mlir_thread_pool_ptr() const {
|
||||
std::stringstream ss;
|
||||
ss << ownedThreadPool.get();
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<llvm::ThreadPoolInterface> ownedThreadPool;
|
||||
};
|
||||
|
||||
/// Wrapper around MlirContext.
|
||||
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
|
||||
class PyMlirContext {
|
||||
|
@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context,
|
||||
unwrap(context)->setThreadPool(*unwrap(threadPool));
|
||||
}
|
||||
|
||||
/// Returns the thread count reported by getNumThreads() on the unwrapped
/// context.
unsigned mlirContextGetNumThreads(MlirContext context) {
  auto *ctx = unwrap(context);
  return ctx->getNumThreads();
}
|
||||
|
||||
/// Takes the address of the thread pool returned by getThreadPool() on the
/// unwrapped context and wraps it as a C API handle.
MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) {
  auto &pool = unwrap(context)->getThreadPool();
  return wrap(&pool);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -148,13 +148,25 @@ def _site_initialize():
|
||||
break
|
||||
|
||||
class Context(ir._BaseContext):
|
||||
def __init__(self, load_on_create_dialects=None, *args, **kwargs):
|
||||
def __init__(
|
||||
self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.append_dialect_registry(get_dialect_registry())
|
||||
for hook in post_init_hooks:
|
||||
hook(self)
|
||||
if disable_multithreading and thread_pool is not None:
|
||||
raise ValueError(
|
||||
"Context constructor has given thread_pool argument, "
|
||||
"but disable_multithreading flag is True. "
|
||||
"Please, set thread_pool argument to None or "
|
||||
"set disable_multithreading flag to False."
|
||||
)
|
||||
if not disable_multithreading:
|
||||
self.enable_multithreading(True)
|
||||
if thread_pool is None:
|
||||
self.enable_multithreading(True)
|
||||
else:
|
||||
self.set_thread_pool(thread_pool)
|
||||
if load_on_create_dialects is not None:
|
||||
logger.debug(
|
||||
"Loading all dialects from load_on_create_dialects arg %r",
|
||||
|
@ -47,3 +47,26 @@ c4_capsule = c4._CAPIPtr
|
||||
assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
|
||||
c5 = mlir.ir.Context._CAPICreate(c4_capsule)
|
||||
assert c4 is c5
|
||||
c4 = None
|
||||
c5 = None
|
||||
gc.collect()
|
||||
|
||||
# Create a global threadpool and use it in two contexts
|
||||
tp = mlir.ir.ThreadPool()
|
||||
assert tp.get_max_concurrency() > 0
|
||||
c5 = mlir.ir.Context()
|
||||
c5.set_thread_pool(tp)
|
||||
assert c5.get_num_threads() == tp.get_max_concurrency()
|
||||
assert c5._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
|
||||
c6 = mlir.ir.Context()
|
||||
c6.set_thread_pool(tp)
|
||||
assert c6.get_num_threads() == tp.get_max_concurrency()
|
||||
assert c6._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
|
||||
c7 = mlir.ir.Context(thread_pool=tp)
|
||||
assert c7.get_num_threads() == tp.get_max_concurrency()
|
||||
assert c7._mlir_thread_pool_ptr() == tp._mlir_thread_pool_ptr()
|
||||
assert mlir.ir.Context._get_live_count() == 3
|
||||
c5 = None
|
||||
c6 = None
|
||||
c7 = None
|
||||
gc.collect()
|
||||
|
Loading…
x
Reference in New Issue
Block a user