[mlir python] Port Python core code to nanobind. (#118583)

Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.

For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.

To a large extent, this is a mechanical change, for instance changing
`pybind11::`
to `nanobind::`.

Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(https://github.com/wjakob/nanobind/pull/806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this ways and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.
This commit is contained in:
Peter Hawkins 2024-12-18 14:16:11 -05:00 committed by GitHub
parent bfd05102d8
commit 41bd35b58b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1870 additions and 1570 deletions

View File

@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
"extension = '${PYTHON_MODULE_EXTENSION}")
mlir_detect_nanobind_install()
find_package(nanobind 2.2 CONFIG REQUIRED)
find_package(nanobind 2.4 CONFIG REQUIRED)
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
"suffix = '${PYTHON_MODULE_SUFFIX}', "

View File

@ -9,7 +9,7 @@
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace mlir {

View File

@ -374,9 +374,8 @@ public:
static_assert(!std::is_member_function_pointer<Func>::value,
"def_staticmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) = py::staticmethod(cf);
return *this;
}
@ -387,9 +386,8 @@ public:
static_assert(!std::is_member_function_pointer<Func>::value,
"def_classmethod(...) called with a non-static member "
"function pointer");
py::cpp_function cf(
std::forward<Func>(f), py::name(name), py::scope(thisClass),
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
py::cpp_function cf(std::forward<Func>(f), py::name(name),
py::scope(thisClass), extra...);
thisClass.attr(cf.name()) =
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
return *this;

View File

@ -9,18 +9,17 @@
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
#include "PybindUtils.h"
#include <optional>
#include <string>
#include <vector>
#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
#include <optional>
#include <string>
#include <vector>
namespace mlir {
namespace python {
@ -57,55 +56,55 @@ public:
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
pybind11::function pyFunc,
nanobind::callable pyFunc,
bool replace = false);
/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
bool replace = false);
/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
nanobind::callable valueCaster,
bool replace = false);
/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
void registerDialectImpl(const std::string &dialectNamespace,
pybind11::object pyClass);
nanobind::object pyClass);
/// Adds a concrete implementation operation class.
/// Raises an exception if the mapping already exists and replace == false.
/// This is intended to be called by implementation code.
void registerOperationImpl(const std::string &operationName,
pybind11::object pyClass, bool replace = false);
nanobind::object pyClass, bool replace = false);
/// Returns the custom Attribute builder for Attribute kind.
std::optional<pybind11::function>
std::optional<nanobind::callable>
lookupAttributeBuilder(const std::string &attributeKind);
/// Returns the custom type caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
/// Returns the custom value caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);
/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
std::optional<nanobind::object>
lookupDialectClass(const std::string &dialectNamespace);
/// Looks up a registered operation class (deriving from OpView) by operation
/// name. Note that this may trigger a load of the dialect, which can
/// arbitrarily re-enter.
std::optional<pybind11::object>
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
private:
@ -113,15 +112,15 @@ private:
/// Module name prefixes to search under for dialect implementation modules.
std::vector<std::string> dialectSearchPrefixes;
/// Map of dialect namespace to external dialect class object.
llvm::StringMap<pybind11::object> dialectClassMap;
llvm::StringMap<nanobind::object> dialectClassMap;
/// Map of full operation name to external operation class object.
llvm::StringMap<pybind11::object> operationClassMap;
llvm::StringMap<nanobind::object> operationClassMap;
/// Map of attribute ODS name to custom builder.
llvm::StringMap<pybind11::object> attributeBuilderMap;
llvm::StringMap<nanobind::callable> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
/// Map of MlirTypeID to custom value caster.
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;

View File

@ -6,20 +6,19 @@
//
//===----------------------------------------------------------------------===//
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <cstddef>
#include <cstdint>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include "IRModule.h"
#include "PybindUtils.h"
#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Bindings/Python/Interop.h"
@ -30,7 +29,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@ -46,23 +45,23 @@ static const char kDumpDocstring[] =
/// Throws errors in case of failure, using "action" to describe what the caller
/// was attempting to do.
template <typename PyType, typename CType>
static void pyListToVector(const py::list &list,
static void pyListToVector(const nb::list &list,
llvm::SmallVectorImpl<CType> &result,
StringRef action) {
result.reserve(py::len(list));
for (py::handle item : list) {
result.reserve(nb::len(list));
for (nb::handle item : list) {
try {
result.push_back(item.cast<PyType>());
} catch (py::cast_error &err) {
result.push_back(nb::cast<PyType>(item));
} catch (nb::cast_error &err) {
std::string msg = (llvm::Twine("Invalid expression when ") + action +
" (" + err.what() + ")")
.str();
throw py::cast_error(msg);
} catch (py::reference_cast_error &err) {
throw std::runtime_error(msg.c_str());
} catch (std::runtime_error &err) {
std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
action + " (" + err.what() + ")")
.str();
throw py::cast_error(msg);
throw std::runtime_error(msg.c_str());
}
}
}
@ -94,7 +93,7 @@ public:
// IsAFunctionTy isaFunction
// const char *pyClassName
// and redefine bindDerived.
using ClassTy = py::class_<DerivedTy, BaseTy>;
using ClassTy = nb::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirAffineExpr);
PyConcreteAffineExpr() = default;
@ -105,24 +104,25 @@ public:
static MlirAffineExpr castFrom(PyAffineExpr &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
throw py::value_error((Twine("Cannot cast affine expression to ") +
auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
throw nb::value_error((Twine("Cannot cast affine expression to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
.str());
.str()
.c_str());
}
return orig;
}
static void bind(py::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
static void bind(nb::module_ &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
cls.def_static(
"isinstance",
[](PyAffineExpr &otherAffineExpr) -> bool {
return DerivedTy::isaFunction(otherAffineExpr);
},
py::arg("other"));
nb::arg("other"));
DerivedTy::bindDerived(cls);
}
@ -144,9 +144,9 @@ public:
}
static void bindDerived(ClassTy &c) {
c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
py::arg("context") = py::none());
c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
nb::arg("context").none() = nb::none());
c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
return mlirAffineConstantExprGetValue(self);
});
}
@ -164,9 +164,9 @@ public:
}
static void bindDerived(ClassTy &c) {
c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
py::arg("context") = py::none());
c.def_property_readonly("position", [](PyAffineDimExpr &self) {
c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
nb::arg("context").none() = nb::none());
c.def_prop_ro("position", [](PyAffineDimExpr &self) {
return mlirAffineDimExprGetPosition(self);
});
}
@ -184,9 +184,9 @@ public:
}
static void bindDerived(ClassTy &c) {
c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
py::arg("context") = py::none());
c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
nb::arg("context").none() = nb::none());
c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
return mlirAffineSymbolExprGetPosition(self);
});
}
@ -209,8 +209,8 @@ public:
}
static void bindDerived(ClassTy &c) {
c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
}
};
@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
return mlirAffineExprEqual(affineExpr, other.affineExpr);
}
py::object PyAffineExpr::getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonAffineExprToCapsule(*this));
nb::object PyAffineExpr::getCapsule() {
return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
}
PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
if (mlirAffineExprIsNull(rawAffineExpr))
throw py::error_already_set();
throw nb::python_error();
return PyAffineExpr(
PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
rawAffineExpr);
@ -424,14 +423,14 @@ bool PyAffineMap::operator==(const PyAffineMap &other) const {
return mlirAffineMapEqual(affineMap, other.affineMap);
}
py::object PyAffineMap::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
nb::object PyAffineMap::getCapsule() {
return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
}
PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
if (mlirAffineMapIsNull(rawAffineMap))
throw py::error_already_set();
throw nb::python_error();
return PyAffineMap(
PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
rawAffineMap);
@ -454,11 +453,10 @@ public:
bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
static void bind(py::module &m) {
py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
py::module_local())
.def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
.def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
static void bind(nb::module_ &m) {
nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
.def_prop_ro("expr", &PyIntegerSetConstraint::getExpr)
.def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq);
}
private:
@ -501,27 +499,25 @@ bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
return mlirIntegerSetEqual(integerSet, other.integerSet);
}
py::object PyIntegerSet::getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonIntegerSetToCapsule(*this));
nb::object PyIntegerSet::getCapsule() {
return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
}
PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
if (mlirIntegerSetIsNull(rawIntegerSet))
throw py::error_already_set();
throw nb::python_error();
return PyIntegerSet(
PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
rawIntegerSet);
}
void mlir::python::populateIRAffine(py::module &m) {
void mlir::python::populateIRAffine(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineExpr and derived classes.
//----------------------------------------------------------------------------
py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyAffineExpr::getCapsule)
nb::class_<PyAffineExpr>(m, "AffineExpr")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
.def("__add__", &PyAffineAddExpr::get)
.def("__add__", &PyAffineAddExpr::getRHSConstant)
@ -558,7 +554,7 @@ void mlir::python::populateIRAffine(py::module &m) {
.def("__eq__", [](PyAffineExpr &self,
PyAffineExpr &other) { return self == other; })
.def("__eq__",
[](PyAffineExpr &self, py::object &other) { return false; })
[](PyAffineExpr &self, nb::object &other) { return false; })
.def("__str__",
[](PyAffineExpr &self) {
PyPrintAccumulator printAccum;
@ -579,7 +575,7 @@ void mlir::python::populateIRAffine(py::module &m) {
[](PyAffineExpr &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def_property_readonly(
.def_prop_ro(
"context",
[](PyAffineExpr &self) { return self.getContext().getObject(); })
.def("compose",
@ -632,16 +628,16 @@ void mlir::python::populateIRAffine(py::module &m) {
.def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
"Gets an affine expression containing the rounded-up result "
"of dividing an expression by a constant.")
.def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
py::arg("context") = py::none(),
.def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"),
nb::arg("context").none() = nb::none(),
"Gets a constant affine expression with the given value.")
.def_static(
"get_dim", &PyAffineDimExpr::get, py::arg("position"),
py::arg("context") = py::none(),
"get_dim", &PyAffineDimExpr::get, nb::arg("position"),
nb::arg("context").none() = nb::none(),
"Gets an affine expression of a dimension at the given position.")
.def_static(
"get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
py::arg("context") = py::none(),
"get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"),
nb::arg("context").none() = nb::none(),
"Gets an affine expression of a symbol at the given position.")
.def(
"dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
@ -659,13 +655,12 @@ void mlir::python::populateIRAffine(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of PyAffineMap.
//----------------------------------------------------------------------------
py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyAffineMap::getCapsule)
nb::class_<PyAffineMap>(m, "AffineMap")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
.def("__eq__",
[](PyAffineMap &self, PyAffineMap &other) { return self == other; })
.def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
.def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; })
.def("__str__",
[](PyAffineMap &self) {
PyPrintAccumulator printAccum;
@ -687,7 +682,7 @@ void mlir::python::populateIRAffine(py::module &m) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def_static("compress_unused_symbols",
[](py::list affineMaps, DefaultingPyMlirContext context) {
[](nb::list affineMaps, DefaultingPyMlirContext context) {
SmallVector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
@ -704,7 +699,7 @@ void mlir::python::populateIRAffine(py::module &m) {
res.emplace_back(context->getRef(), m);
return res;
})
.def_property_readonly(
.def_prop_ro(
"context",
[](PyAffineMap &self) { return self.getContext().getObject(); },
"Context that owns the Affine Map")
@ -713,7 +708,7 @@ void mlir::python::populateIRAffine(py::module &m) {
kDumpDocstring)
.def_static(
"get",
[](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
[](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
DefaultingPyMlirContext context) {
SmallVector<MlirAffineExpr> affineExprs;
pyListToVector<PyAffineExpr, MlirAffineExpr>(
@ -723,8 +718,8 @@ void mlir::python::populateIRAffine(py::module &m) {
affineExprs.size(), affineExprs.data());
return PyAffineMap(context->getRef(), map);
},
py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
py::arg("context") = py::none(),
nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"),
nb::arg("context").none() = nb::none(),
"Gets a map with the given expressions as results.")
.def_static(
"get_constant",
@ -733,7 +728,7 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirAffineMapConstantGet(context->get(), value);
return PyAffineMap(context->getRef(), affineMap);
},
py::arg("value"), py::arg("context") = py::none(),
nb::arg("value"), nb::arg("context").none() = nb::none(),
"Gets an affine map with a single constant result")
.def_static(
"get_empty",
@ -741,7 +736,7 @@ void mlir::python::populateIRAffine(py::module &m) {
MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
return PyAffineMap(context->getRef(), affineMap);
},
py::arg("context") = py::none(), "Gets an empty affine map.")
nb::arg("context").none() = nb::none(), "Gets an empty affine map.")
.def_static(
"get_identity",
[](intptr_t nDims, DefaultingPyMlirContext context) {
@ -749,7 +744,7 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
return PyAffineMap(context->getRef(), affineMap);
},
py::arg("n_dims"), py::arg("context") = py::none(),
nb::arg("n_dims"), nb::arg("context").none() = nb::none(),
"Gets an identity map with the given number of dimensions.")
.def_static(
"get_minor_identity",
@ -759,8 +754,8 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
return PyAffineMap(context->getRef(), affineMap);
},
py::arg("n_dims"), py::arg("n_results"),
py::arg("context") = py::none(),
nb::arg("n_dims"), nb::arg("n_results"),
nb::arg("context").none() = nb::none(),
"Gets a minor identity map with the given number of dimensions and "
"results.")
.def_static(
@ -768,13 +763,13 @@ void mlir::python::populateIRAffine(py::module &m) {
[](std::vector<unsigned> permutation,
DefaultingPyMlirContext context) {
if (!isPermutation(permutation))
throw py::cast_error("Invalid permutation when attempting to "
"create an AffineMap");
throw std::runtime_error("Invalid permutation when attempting to "
"create an AffineMap");
MlirAffineMap affineMap = mlirAffineMapPermutationGet(
context->get(), permutation.size(), permutation.data());
return PyAffineMap(context->getRef(), affineMap);
},
py::arg("permutation"), py::arg("context") = py::none(),
nb::arg("permutation"), nb::arg("context").none() = nb::none(),
"Gets an affine map that permutes its inputs.")
.def(
"get_submap",
@ -782,33 +777,33 @@ void mlir::python::populateIRAffine(py::module &m) {
intptr_t numResults = mlirAffineMapGetNumResults(self);
for (intptr_t pos : resultPos) {
if (pos < 0 || pos >= numResults)
throw py::value_error("result position out of bounds");
throw nb::value_error("result position out of bounds");
}
MlirAffineMap affineMap = mlirAffineMapGetSubMap(
self, resultPos.size(), resultPos.data());
return PyAffineMap(self.getContext(), affineMap);
},
py::arg("result_positions"))
nb::arg("result_positions"))
.def(
"get_major_submap",
[](PyAffineMap &self, intptr_t nResults) {
if (nResults >= mlirAffineMapGetNumResults(self))
throw py::value_error("number of results out of bounds");
throw nb::value_error("number of results out of bounds");
MlirAffineMap affineMap =
mlirAffineMapGetMajorSubMap(self, nResults);
return PyAffineMap(self.getContext(), affineMap);
},
py::arg("n_results"))
nb::arg("n_results"))
.def(
"get_minor_submap",
[](PyAffineMap &self, intptr_t nResults) {
if (nResults >= mlirAffineMapGetNumResults(self))
throw py::value_error("number of results out of bounds");
throw nb::value_error("number of results out of bounds");
MlirAffineMap affineMap =
mlirAffineMapGetMinorSubMap(self, nResults);
return PyAffineMap(self.getContext(), affineMap);
},
py::arg("n_results"))
nb::arg("n_results"))
.def(
"replace",
[](PyAffineMap &self, PyAffineExpr &expression,
@ -818,39 +813,37 @@ void mlir::python::populateIRAffine(py::module &m) {
self, expression, replacement, numResultDims, numResultSyms);
return PyAffineMap(self.getContext(), affineMap);
},
py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
py::arg("n_result_syms"))
.def_property_readonly(
nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"),
nb::arg("n_result_syms"))
.def_prop_ro(
"is_permutation",
[](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
.def_property_readonly("is_projected_permutation",
[](PyAffineMap &self) {
return mlirAffineMapIsProjectedPermutation(self);
})
.def_property_readonly(
.def_prop_ro("is_projected_permutation",
[](PyAffineMap &self) {
return mlirAffineMapIsProjectedPermutation(self);
})
.def_prop_ro(
"n_dims",
[](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
.def_property_readonly(
.def_prop_ro(
"n_inputs",
[](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
.def_property_readonly(
.def_prop_ro(
"n_symbols",
[](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
.def_property_readonly("results", [](PyAffineMap &self) {
return PyAffineMapExprList(self);
});
.def_prop_ro("results",
[](PyAffineMap &self) { return PyAffineMapExprList(self); });
PyAffineMapExprList::bind(m);
//----------------------------------------------------------------------------
// Mapping of PyIntegerSet.
//----------------------------------------------------------------------------
py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyIntegerSet::getCapsule)
nb::class_<PyIntegerSet>(m, "IntegerSet")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
.def("__eq__", [](PyIntegerSet &self,
PyIntegerSet &other) { return self == other; })
.def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
.def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
.def("__str__",
[](PyIntegerSet &self) {
PyPrintAccumulator printAccum;
@ -871,7 +864,7 @@ void mlir::python::populateIRAffine(py::module &m) {
[](PyIntegerSet &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def_property_readonly(
.def_prop_ro(
"context",
[](PyIntegerSet &self) { return self.getContext().getObject(); })
.def(
@ -879,14 +872,14 @@ void mlir::python::populateIRAffine(py::module &m) {
kDumpDocstring)
.def_static(
"get",
[](intptr_t numDims, intptr_t numSymbols, py::list exprs,
[](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
if (exprs.size() != eqFlags.size())
throw py::value_error(
throw nb::value_error(
"Expected the number of constraints to match "
"that of equality flags");
if (exprs.empty())
throw py::value_error("Expected non-empty list of constraints");
if (exprs.size() == 0)
throw nb::value_error("Expected non-empty list of constraints");
// Copy over to a SmallVector because std::vector has a
// specialization for booleans that packs data and does not
@ -901,8 +894,8 @@ void mlir::python::populateIRAffine(py::module &m) {
affineExprs.data(), flags.data());
return PyIntegerSet(context->getRef(), set);
},
py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
py::arg("eq_flags"), py::arg("context") = py::none())
nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"),
nb::arg("eq_flags"), nb::arg("context").none() = nb::none())
.def_static(
"get_empty",
[](intptr_t numDims, intptr_t numSymbols,
@ -911,20 +904,20 @@ void mlir::python::populateIRAffine(py::module &m) {
mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
return PyIntegerSet(context->getRef(), set);
},
py::arg("num_dims"), py::arg("num_symbols"),
py::arg("context") = py::none())
nb::arg("num_dims"), nb::arg("num_symbols"),
nb::arg("context").none() = nb::none())
.def(
"get_replaced",
[](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
[](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
intptr_t numResultDims, intptr_t numResultSymbols) {
if (static_cast<intptr_t>(dimExprs.size()) !=
mlirIntegerSetGetNumDims(self))
throw py::value_error(
throw nb::value_error(
"Expected the number of dimension replacement expressions "
"to match that of dimensions");
if (static_cast<intptr_t>(symbolExprs.size()) !=
mlirIntegerSetGetNumSymbols(self))
throw py::value_error(
throw nb::value_error(
"Expected the number of symbol replacement expressions "
"to match that of symbols");
@ -940,30 +933,30 @@ void mlir::python::populateIRAffine(py::module &m) {
numResultDims, numResultSymbols);
return PyIntegerSet(self.getContext(), set);
},
py::arg("dim_exprs"), py::arg("symbol_exprs"),
py::arg("num_result_dims"), py::arg("num_result_symbols"))
.def_property_readonly("is_canonical_empty",
[](PyIntegerSet &self) {
return mlirIntegerSetIsCanonicalEmpty(self);
})
.def_property_readonly(
nb::arg("dim_exprs"), nb::arg("symbol_exprs"),
nb::arg("num_result_dims"), nb::arg("num_result_symbols"))
.def_prop_ro("is_canonical_empty",
[](PyIntegerSet &self) {
return mlirIntegerSetIsCanonicalEmpty(self);
})
.def_prop_ro(
"n_dims",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
.def_property_readonly(
.def_prop_ro(
"n_symbols",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
.def_property_readonly(
.def_prop_ro(
"n_inputs",
[](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
.def_property_readonly("n_equalities",
[](PyIntegerSet &self) {
return mlirIntegerSetGetNumEqualities(self);
})
.def_property_readonly("n_inequalities",
[](PyIntegerSet &self) {
return mlirIntegerSetGetNumInequalities(self);
})
.def_property_readonly("constraints", [](PyIntegerSet &self) {
.def_prop_ro("n_equalities",
[](PyIntegerSet &self) {
return mlirIntegerSetGetNumEqualities(self);
})
.def_prop_ro("n_inequalities",
[](PyIntegerSet &self) {
return mlirIntegerSetGetNumInequalities(self);
})
.def_prop_ro("constraints", [](PyIntegerSet &self) {
return PyIntegerSetConstraintList(self);
});
PyIntegerSetConstraint::bind(m);

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/vector.h>
#include <cstdint>
#include <optional>
#include <pybind11/cast.h>
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <string>
#include <utility>
#include <vector>
@ -24,7 +24,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace py = pybind11;
namespace nb = nanobind;
namespace mlir {
namespace python {
@ -53,10 +53,10 @@ namespace {
/// Takes in an optional ist of operands and converts them into a SmallVector
/// of MlirVlaues. Returns an empty SmallVector if the list is empty.
llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
llvm::SmallVector<MlirValue> mlirOperands;
if (!operandList || operandList->empty()) {
if (!operandList || operandList->size() == 0) {
return mlirOperands;
}
@ -68,40 +68,42 @@ llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
PyValue *val;
try {
val = py::cast<PyValue *>(it.value());
val = nb::cast<PyValue *>(it.value());
if (!val)
throw py::cast_error();
throw nb::cast_error();
mlirOperands.push_back(val->get());
continue;
} catch (py::cast_error &err) {
} catch (nb::cast_error &err) {
// Intentionally unhandled to try sequence below first.
(void)err;
}
try {
auto vals = py::cast<py::sequence>(it.value());
for (py::object v : vals) {
auto vals = nb::cast<nb::sequence>(it.value());
for (nb::handle v : vals) {
try {
val = py::cast<PyValue *>(v);
val = nb::cast<PyValue *>(v);
if (!val)
throw py::cast_error();
throw nb::cast_error();
mlirOperands.push_back(val->get());
} catch (py::cast_error &err) {
throw py::value_error(
} catch (nb::cast_error &err) {
throw nb::value_error(
(llvm::Twine("Operand ") + llvm::Twine(it.index()) +
" must be a Value or Sequence of Values (" + err.what() + ")")
.str());
.str()
.c_str());
}
}
continue;
} catch (py::cast_error &err) {
throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
} catch (nb::cast_error &err) {
throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
" must be a Value or Sequence of Values (" +
err.what() + ")")
.str());
.str()
.c_str());
}
throw py::cast_error();
throw nb::cast_error();
}
return mlirOperands;
@ -144,24 +146,24 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
template <typename ConcreteIface>
class PyConcreteOpInterface {
protected:
using ClassTy = py::class_<ConcreteIface>;
using ClassTy = nb::class_<ConcreteIface>;
using GetTypeIDFunctionTy = MlirTypeID (*)();
public:
/// Constructs an interface instance from an object that is either an
/// operation or a subclass of OpView. In the latter case, only the static
/// methods of the interface are accessible to the caller.
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
: obj(std::move(object)) {
try {
operation = &py::cast<PyOperation &>(obj);
} catch (py::cast_error &) {
operation = &nb::cast<PyOperation &>(obj);
} catch (nb::cast_error &) {
// Do nothing.
}
try {
operation = &py::cast<PyOpView &>(obj).getOperation();
} catch (py::cast_error &) {
operation = &nb::cast<PyOpView &>(obj).getOperation();
} catch (nb::cast_error &) {
// Do nothing.
}
@ -169,7 +171,7 @@ public:
if (!mlirOperationImplementsInterface(*operation,
ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw py::value_error(msg + ConcreteIface::pyClassName);
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
}
MlirIdentifier identifier = mlirOperationGetName(*operation);
@ -177,9 +179,9 @@ public:
opName = std::string(stringRef.data, stringRef.length);
} else {
try {
opName = obj.attr("OPERATION_NAME").template cast<std::string>();
} catch (py::cast_error &) {
throw py::type_error(
opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
} catch (nb::cast_error &) {
throw nb::type_error(
"Op interface does not refer to an operation or OpView class");
}
@ -187,22 +189,19 @@ public:
mlirStringRefCreate(opName.data(), opName.length()),
context.resolve().get(), ConcreteIface::getInterfaceID())) {
std::string msg = "the operation does not implement ";
throw py::value_error(msg + ConcreteIface::pyClassName);
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
}
}
}
/// Creates the Python bindings for this class in the given module.
static void bind(py::module &m) {
py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
py::module_local());
cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
py::arg("context") = py::none(), constructorDoc)
.def_property_readonly("operation",
&PyConcreteOpInterface::getOperationObject,
operationDoc)
.def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
opviewDoc);
static void bind(nb::module_ &m) {
nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
nb::arg("context").none() = nb::none(), constructorDoc)
.def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
operationDoc)
.def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
ConcreteIface::bindDerived(cls);
}
@ -216,9 +215,9 @@ public:
/// Returns the operation instance from which this object was constructed.
/// Throws a type error if this object was constructed from a subclass of
/// OpView.
py::object getOperationObject() {
nb::object getOperationObject() {
if (operation == nullptr) {
throw py::type_error("Cannot get an operation from a static interface");
throw nb::type_error("Cannot get an operation from a static interface");
}
return operation->getRef().releaseObject();
@ -227,9 +226,9 @@ public:
/// Returns the opview of the operation instance from which this object was
/// constructed. Throws a type error if this object was constructed form a
/// subclass of OpView.
py::object getOpView() {
nb::object getOpView() {
if (operation == nullptr) {
throw py::type_error("Cannot get an opview from a static interface");
throw nb::type_error("Cannot get an opview from a static interface");
}
return operation->createOpView();
@ -242,7 +241,7 @@ public:
private:
PyOperation *operation = nullptr;
std::string opName;
py::object obj;
nb::object obj;
};
/// Python wrapper for InferTypeOpInterface. This interface has only static
@ -276,7 +275,7 @@ public:
/// Given the arguments required to build an operation, attempts to infer its
/// return types. Throws value_error on failure.
std::vector<PyType>
inferReturnTypes(std::optional<py::list> operandList,
inferReturnTypes(std::optional<nb::list> operandList,
std::optional<PyAttribute> attributes, void *properties,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context,
@ -299,7 +298,7 @@ public:
mlirRegions.data(), &appendResultsCallback, &data);
if (mlirLogicalResultIsFailure(result)) {
throw py::value_error("Failed to infer result types");
throw nb::value_error("Failed to infer result types");
}
return inferredTypes;
@ -307,11 +306,12 @@ public:
static void bindDerived(ClassTy &cls) {
cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
py::arg("operands") = py::none(),
py::arg("attributes") = py::none(),
py::arg("properties") = py::none(), py::arg("regions") = py::none(),
py::arg("context") = py::none(), py::arg("loc") = py::none(),
inferReturnTypesDoc);
nb::arg("operands").none() = nb::none(),
nb::arg("attributes").none() = nb::none(),
nb::arg("properties").none() = nb::none(),
nb::arg("regions").none() = nb::none(),
nb::arg("context").none() = nb::none(),
nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
}
};
@ -319,9 +319,9 @@ public:
class PyShapedTypeComponents {
public:
PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
PyShapedTypeComponents(py::list shape, MlirType elementType)
PyShapedTypeComponents(nb::list shape, MlirType elementType)
: shape(std::move(shape)), elementType(elementType), ranked(true) {}
PyShapedTypeComponents(py::list shape, MlirType elementType,
PyShapedTypeComponents(nb::list shape, MlirType elementType,
MlirAttribute attribute)
: shape(std::move(shape)), elementType(elementType), attribute(attribute),
ranked(true) {}
@ -330,10 +330,9 @@ public:
: shape(other.shape), elementType(other.elementType),
attribute(other.attribute), ranked(other.ranked) {}
static void bind(py::module &m) {
py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
py::module_local())
.def_property_readonly(
static void bind(nb::module_ &m) {
nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
.def_prop_ro(
"element_type",
[](PyShapedTypeComponents &self) { return self.elementType; },
"Returns the element type of the shaped type components.")
@ -342,57 +341,57 @@ public:
[](PyType &elementType) {
return PyShapedTypeComponents(elementType);
},
py::arg("element_type"),
nb::arg("element_type"),
"Create an shaped type components object with only the element "
"type.")
.def_static(
"get",
[](py::list shape, PyType &elementType) {
[](nb::list shape, PyType &elementType) {
return PyShapedTypeComponents(std::move(shape), elementType);
},
py::arg("shape"), py::arg("element_type"),
nb::arg("shape"), nb::arg("element_type"),
"Create a ranked shaped type components object.")
.def_static(
"get",
[](py::list shape, PyType &elementType, PyAttribute &attribute) {
[](nb::list shape, PyType &elementType, PyAttribute &attribute) {
return PyShapedTypeComponents(std::move(shape), elementType,
attribute);
},
py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
"Create a ranked shaped type components object with attribute.")
.def_property_readonly(
.def_prop_ro(
"has_rank",
[](PyShapedTypeComponents &self) -> bool { return self.ranked; },
"Returns whether the given shaped type component is ranked.")
.def_property_readonly(
.def_prop_ro(
"rank",
[](PyShapedTypeComponents &self) -> py::object {
[](PyShapedTypeComponents &self) -> nb::object {
if (!self.ranked) {
return py::none();
return nb::none();
}
return py::int_(self.shape.size());
return nb::int_(self.shape.size());
},
"Returns the rank of the given ranked shaped type components. If "
"the shaped type components does not have a rank, None is "
"returned.")
.def_property_readonly(
.def_prop_ro(
"shape",
[](PyShapedTypeComponents &self) -> py::object {
[](PyShapedTypeComponents &self) -> nb::object {
if (!self.ranked) {
return py::none();
return nb::none();
}
return py::list(self.shape);
return nb::list(self.shape);
},
"Returns the shape of the ranked shaped type components as a list "
"of integers. Returns none if the shaped type component does not "
"have a rank.");
}
pybind11::object getCapsule();
static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
nb::object getCapsule();
static PyShapedTypeComponents createFromCapsule(nb::object capsule);
private:
py::list shape;
nb::list shape;
MlirType elementType;
MlirAttribute attribute;
bool ranked{false};
@ -424,7 +423,7 @@ public:
if (!hasRank) {
data->inferredShapedTypeComponents.emplace_back(elementType);
} else {
py::list shapeList;
nb::list shapeList;
for (intptr_t i = 0; i < rank; ++i) {
shapeList.append(shape[i]);
}
@ -436,7 +435,7 @@ public:
/// Given the arguments required to build an operation, attempts to infer the
/// shaped type components. Throws value_error on failure.
std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
std::optional<py::list> operandList,
std::optional<nb::list> operandList,
std::optional<PyAttribute> attributes, void *properties,
std::optional<std::vector<PyRegion>> regions,
DefaultingPyMlirContext context, DefaultingPyLocation location) {
@ -458,7 +457,7 @@ public:
mlirRegions.data(), &appendResultsCallback, &data);
if (mlirLogicalResultIsFailure(result)) {
throw py::value_error("Failed to infer result shape type components");
throw nb::value_error("Failed to infer result shape type components");
}
return inferredShapedTypeComponents;
@ -467,14 +466,16 @@ public:
static void bindDerived(ClassTy &cls) {
cls.def("inferReturnTypeComponents",
&PyInferShapedTypeOpInterface::inferReturnTypeComponents,
py::arg("operands") = py::none(),
py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
py::arg("properties") = py::none(), py::arg("context") = py::none(),
py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
nb::arg("operands").none() = nb::none(),
nb::arg("attributes").none() = nb::none(),
nb::arg("regions").none() = nb::none(),
nb::arg("properties").none() = nb::none(),
nb::arg("context").none() = nb::none(),
nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
}
};
void populateIRInterfaces(py::module &m) {
void populateIRInterfaces(nb::module_ &m) {
PyInferTypeOpInterface::bind(m);
PyShapedTypeComponents::bind(m);
PyInferShapedTypeOpInterface::bind(m);

View File

@ -7,16 +7,19 @@
//===----------------------------------------------------------------------===//
#include "IRModule.h"
#include "Globals.h"
#include "PybindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <optional>
#include <vector>
namespace py = pybind11;
#include "Globals.h"
#include "NanobindUtils.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@ -41,14 +44,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
return true;
// Since re-entrancy is possible, make a copy of the search prefixes.
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
py::object loaded = py::none();
nb::object loaded = nb::none();
for (std::string moduleName : localSearchPrefixes) {
moduleName.push_back('.');
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
try {
loaded = py::module::import(moduleName.c_str());
} catch (py::error_already_set &e) {
loaded = nb::module_::import_(moduleName.c_str());
} catch (nb::python_error &e) {
if (e.matches(PyExc_ModuleNotFoundError)) {
continue;
}
@ -66,41 +69,39 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
py::function pyFunc, bool replace) {
py::object &found = attributeBuilderMap[attributeKind];
nb::callable pyFunc, bool replace) {
nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
attributeKind +
"' is already registered with func: " +
py::str(found).operator std::string())
nb::cast<std::string>(nb::str(found)))
.str());
}
found = std::move(pyFunc);
}
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
pybind11::function typeCaster,
bool replace) {
pybind11::object &found = typeCasterMap[mlirTypeID];
nb::callable typeCaster, bool replace) {
nb::object &found = typeCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Type caster is already registered with caster: " +
py::str(found).operator std::string());
nb::cast<std::string>(nb::str(found)));
found = std::move(typeCaster);
}
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
bool replace) {
pybind11::object &found = valueCasterMap[mlirTypeID];
nb::callable valueCaster, bool replace) {
nb::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
py::repr(found).cast<std::string>());
nb::cast<std::string>(nb::repr(found)));
found = std::move(valueCaster);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
nb::object pyClass) {
nb::object &found = dialectClassMap[dialectNamespace];
if (found) {
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
dialectNamespace + "' is already registered.")
@ -110,8 +111,8 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
py::object pyClass, bool replace) {
py::object &found = operationClassMap[operationName];
nb::object pyClass, bool replace) {
nb::object &found = operationClassMap[operationName];
if (found && !replace) {
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
"' is already registered.")
@ -120,7 +121,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
found = std::move(pyClass);
}
std::optional<py::function>
std::optional<nb::callable>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
const auto foundIt = attributeBuilderMap.find(attributeKind);
if (foundIt != attributeBuilderMap.end()) {
@ -130,7 +131,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
return std::nullopt;
}
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
@ -142,7 +143,7 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}
std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
// Try to load dialect module.
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
@ -154,7 +155,7 @@ std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}
std::optional<py::object>
std::optional<nb::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
if (!loadDialectModule(dialectNamespace))
@ -168,7 +169,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
return std::nullopt;
}
std::optional<pybind11::object>
std::optional<nb::object>
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Make sure dialect module is loaded.
auto split = operationName.split('.');

View File

@ -10,20 +10,22 @@
#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include <optional>
#include <utility>
#include <vector>
#include "Globals.h"
#include "PybindUtils.h"
#include "NanobindUtils.h"
#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Transforms.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "llvm/ADT/DenseMap.h"
namespace mlir {
@ -49,7 +51,7 @@ class PyValue;
template <typename T>
class PyObjectRef {
public:
PyObjectRef(T *referrent, pybind11::object object)
PyObjectRef(T *referrent, nanobind::object object)
: referrent(referrent), object(std::move(object)) {
assert(this->referrent &&
"cannot construct PyObjectRef with null referrent");
@ -67,13 +69,13 @@ public:
int getRefCount() {
if (!object)
return 0;
return object.ref_count();
return Py_REFCNT(object.ptr());
}
/// Releases the object held by this instance, returning it.
/// This is the proper thing to return from a function that wants to return
/// the reference. Note that this does not work from initializers.
pybind11::object releaseObject() {
nanobind::object releaseObject() {
assert(referrent && object);
referrent = nullptr;
auto stolen = std::move(object);
@ -85,7 +87,7 @@ public:
assert(referrent && object);
return referrent;
}
pybind11::object getObject() {
nanobind::object getObject() {
assert(referrent && object);
return object;
}
@ -93,7 +95,7 @@ public:
private:
T *referrent;
pybind11::object object;
nanobind::object object;
};
/// Tracks an entry in the thread context stack. New entries are pushed onto
@ -112,9 +114,9 @@ public:
Location,
};
PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
pybind11::object insertionPoint,
pybind11::object location)
PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
nanobind::object insertionPoint,
nanobind::object location)
: context(std::move(context)), insertionPoint(std::move(insertionPoint)),
location(std::move(location)), frameKind(frameKind) {}
@ -134,26 +136,26 @@ public:
/// Stack management.
static PyThreadContextEntry *getTopOfStack();
static pybind11::object pushContext(PyMlirContext &context);
static nanobind::object pushContext(nanobind::object context);
static void popContext(PyMlirContext &context);
static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint);
static void popInsertionPoint(PyInsertionPoint &insertionPoint);
static pybind11::object pushLocation(PyLocation &location);
static nanobind::object pushLocation(nanobind::object location);
static void popLocation(PyLocation &location);
/// Gets the thread local stack.
static std::vector<PyThreadContextEntry> &getStack();
private:
static void push(FrameKind frameKind, pybind11::object context,
pybind11::object insertionPoint, pybind11::object location);
static void push(FrameKind frameKind, nanobind::object context,
nanobind::object insertionPoint, nanobind::object location);
/// An object reference to the PyContext.
pybind11::object context;
nanobind::object context;
/// An object reference to the current insertion point.
pybind11::object insertionPoint;
nanobind::object insertionPoint;
/// An object reference to the current location.
pybind11::object location;
nanobind::object location;
// The kind of push that was performed.
FrameKind frameKind;
};
@ -163,14 +165,15 @@ using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
public:
PyMlirContext() = delete;
PyMlirContext(MlirContext context);
PyMlirContext(const PyMlirContext &) = delete;
PyMlirContext(PyMlirContext &&) = delete;
/// For the case of a python __init__ (py::init) method, pybind11 is quite
/// strict about needing to return a pointer that is not yet associated to
/// an py::object. Since the forContext() method acts like a pool, possibly
/// returning a recycled context, it does not satisfy this need. The usual
/// way in python to accomplish such a thing is to override __new__, but
/// For the case of a python __init__ (nanobind::init) method, pybind11 is
/// quite strict about needing to return a pointer that is not yet associated
/// to an nanobind::object. Since the forContext() method acts like a pool,
/// possibly returning a recycled context, it does not satisfy this need. The
/// usual way in python to accomplish such a thing is to override __new__, but
/// that is also not supported by pybind11. Instead, we use this entry
/// point which always constructs a fresh context (which cannot alias an
/// existing one because it is fresh).
@ -187,17 +190,17 @@ public:
/// Gets a strong reference to this context, which will ensure it is kept
/// alive for the life of the reference.
PyMlirContextRef getRef() {
return PyMlirContextRef(this, pybind11::cast(this));
return PyMlirContextRef(this, nanobind::cast(this));
}
/// Gets a capsule wrapping the void* within the MlirContext.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
/// Note that PyMlirContext instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirContext
/// is taken by calling this function.
static pybind11::object createFromCapsule(pybind11::object capsule);
static nanobind::object createFromCapsule(nanobind::object capsule);
/// Gets the count of live context objects. Used for testing.
static size_t getLiveCount();
@ -237,14 +240,14 @@ public:
size_t getLiveModuleCount();
/// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(const pybind11::object &excType,
const pybind11::object &excVal,
const pybind11::object &excTb);
static nanobind::object contextEnter(nanobind::object context);
void contextExit(const nanobind::object &excType,
const nanobind::object &excVal,
const nanobind::object &excTb);
/// Attaches a Python callback as a diagnostic handler, returning a
/// registration object (internally a PyDiagnosticHandler).
pybind11::object attachDiagnosticHandler(pybind11::object callback);
nanobind::object attachDiagnosticHandler(nanobind::object callback);
/// Controls whether error diagnostics should be propagated to diagnostic
/// handlers, instead of being captured by `ErrorCapture`.
@ -252,8 +255,6 @@ public:
struct ErrorCapture;
private:
PyMlirContext(MlirContext context);
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
// preserving the relationship that an MlirContext maps to a single
// PyMlirContext wrapper. This could be replaced in the future with an
@ -268,7 +269,7 @@ private:
// from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveModuleMap =
llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
LiveModuleMap liveModules;
// Interns all live operations associated with this context. Operations
@ -276,7 +277,7 @@ private:
// removed from this map, and while it still exists as an instance, any
// attempt to access it will raise an error.
using LiveOperationMap =
llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
LiveOperationMap liveOperations;
bool emitErrorDiagnostics = false;
@ -324,19 +325,19 @@ public:
MlirLocation get() const { return loc; }
/// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(const pybind11::object &excType,
const pybind11::object &excVal,
const pybind11::object &excTb);
static nanobind::object contextEnter(nanobind::object location);
void contextExit(const nanobind::object &excType,
const nanobind::object &excVal,
const nanobind::object &excTb);
/// Gets a capsule wrapping the void* within the MlirLocation.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyLocation from the MlirLocation wrapped by a capsule.
/// Note that PyLocation instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirLocation
/// is taken by calling this function.
static PyLocation createFromCapsule(pybind11::object capsule);
static PyLocation createFromCapsule(nanobind::object capsule);
private:
MlirLocation loc;
@ -353,8 +354,8 @@ public:
bool isValid() { return valid; }
MlirDiagnosticSeverity getSeverity();
PyLocation getLocation();
pybind11::str getMessage();
pybind11::tuple getNotes();
nanobind::str getMessage();
nanobind::tuple getNotes();
/// Materialized diagnostic information. This is safe to access outside the
/// diagnostic callback.
@ -373,7 +374,7 @@ private:
/// If notes have been materialized from the diagnostic, then this will
/// be populated with the corresponding objects (all castable to
/// PyDiagnostic).
std::optional<pybind11::tuple> materializedNotes;
std::optional<nanobind::tuple> materializedNotes;
bool valid = true;
};
@ -398,7 +399,7 @@ private:
/// is no way to attach an existing handler object).
class PyDiagnosticHandler {
public:
PyDiagnosticHandler(MlirContext context, pybind11::object callback);
PyDiagnosticHandler(MlirContext context, nanobind::object callback);
~PyDiagnosticHandler();
bool isAttached() { return registeredID.has_value(); }
@ -407,16 +408,16 @@ public:
/// Detaches the handler. Does nothing if not attached.
void detach();
pybind11::object contextEnter() { return pybind11::cast(this); }
void contextExit(const pybind11::object &excType,
const pybind11::object &excVal,
const pybind11::object &excTb) {
nanobind::object contextEnter() { return nanobind::cast(this); }
void contextExit(const nanobind::object &excType,
const nanobind::object &excVal,
const nanobind::object &excTb) {
detach();
}
private:
MlirContext context;
pybind11::object callback;
nanobind::object callback;
std::optional<MlirDiagnosticHandlerID> registeredID;
bool hadError = false;
friend class PyMlirContext;
@ -477,12 +478,12 @@ public:
/// objects of this type will be returned directly.
class PyDialect {
public:
PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {}
pybind11::object getDescriptor() { return descriptor; }
nanobind::object getDescriptor() { return descriptor; }
private:
pybind11::object descriptor;
nanobind::object descriptor;
};
/// Wrapper around an MlirDialectRegistry.
@ -505,8 +506,8 @@ public:
operator MlirDialectRegistry() const { return registry; }
MlirDialectRegistry get() const { return registry; }
pybind11::object getCapsule();
static PyDialectRegistry createFromCapsule(pybind11::object capsule);
nanobind::object getCapsule();
static PyDialectRegistry createFromCapsule(nanobind::object capsule);
private:
MlirDialectRegistry registry;
@ -542,26 +543,25 @@ public:
/// Gets a strong reference to this module.
PyModuleRef getRef() {
return PyModuleRef(this,
pybind11::reinterpret_borrow<pybind11::object>(handle));
return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle));
}
/// Gets a capsule wrapping the void* within the MlirModule.
/// Note that the module does not (yet) provide a corresponding factory for
/// constructing from a capsule as that would require uniquing PyModule
/// instances, which is not currently done.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyModule from the MlirModule wrapped by a capsule.
/// Note that PyModule instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirModule
/// is taken by calling this function.
static pybind11::object createFromCapsule(pybind11::object capsule);
static nanobind::object createFromCapsule(nanobind::object capsule);
private:
PyModule(PyMlirContextRef contextRef, MlirModule module);
MlirModule module;
pybind11::handle handle;
nanobind::handle handle;
};
class PyAsmState;
@ -574,18 +574,18 @@ public:
/// Implements the bound 'print' method and helps with others.
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, py::object fileObject, bool binary,
bool assumeVerified, nanobind::object fileObject, bool binary,
bool skipRegions);
void print(PyAsmState &state, py::object fileObject, bool binary);
void print(PyAsmState &state, nanobind::object fileObject, bool binary);
pybind11::object getAsm(bool binary,
nanobind::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, bool skipRegions);
// Implement the bound 'writeBytecode' method.
void writeBytecode(const pybind11::object &fileObject,
void writeBytecode(const nanobind::object &fileObject,
std::optional<int64_t> bytecodeVersion);
// Implement the walk method.
@ -621,13 +621,13 @@ public:
/// it with a parentKeepAlive.
static PyOperationRef
forOperation(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());
nanobind::object parentKeepAlive = nanobind::object());
/// Creates a detached operation. The operation must not be associated with
/// any existing live operation.
static PyOperationRef
createDetached(PyMlirContextRef contextRef, MlirOperation operation,
pybind11::object parentKeepAlive = pybind11::object());
nanobind::object parentKeepAlive = nanobind::object());
/// Parses a source string (either text assembly or bytecode), creating a
/// detached operation.
@ -640,7 +640,7 @@ public:
void detachFromParent() {
mlirOperationRemoveFromParent(getOperation());
setDetached();
parentKeepAlive = pybind11::object();
parentKeepAlive = nanobind::object();
}
/// Gets the backing operation.
@ -651,12 +651,11 @@ public:
}
PyOperationRef getRef() {
return PyOperationRef(
this, pybind11::reinterpret_borrow<pybind11::object>(handle));
return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
}
bool isAttached() { return attached; }
void setAttached(const pybind11::object &parent = pybind11::object()) {
void setAttached(const nanobind::object &parent = nanobind::object()) {
assert(!attached && "operation already attached");
attached = true;
}
@ -675,24 +674,24 @@ public:
std::optional<PyOperationRef> getParentOperation();
/// Gets a capsule wrapping the void* within the MlirOperation.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyOperation from the MlirOperation wrapped by a capsule.
/// Ownership of the underlying MlirOperation is taken by calling this
/// function.
static pybind11::object createFromCapsule(pybind11::object capsule);
static nanobind::object createFromCapsule(nanobind::object capsule);
/// Creates an operation. See corresponding python docstring.
static pybind11::object
static nanobind::object
create(const std::string &name, std::optional<std::vector<PyType *>> results,
std::optional<std::vector<PyValue *>> operands,
std::optional<pybind11::dict> attributes,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation location, const pybind11::object &ip,
DefaultingPyLocation location, const nanobind::object &ip,
bool inferType);
/// Creates an OpView suitable for this operation.
pybind11::object createOpView();
nanobind::object createOpView();
/// Erases the underlying MlirOperation, removes its pointer from the
/// parent context's live operations map, and sets the valid bit false.
@ -702,23 +701,23 @@ public:
void setInvalid() { valid = false; }
/// Clones this operation.
pybind11::object clone(const pybind11::object &ip);
nanobind::object clone(const nanobind::object &ip);
private:
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
static PyOperationRef createInstance(PyMlirContextRef contextRef,
MlirOperation operation,
pybind11::object parentKeepAlive);
nanobind::object parentKeepAlive);
MlirOperation operation;
pybind11::handle handle;
nanobind::handle handle;
// Keeps the parent alive, regardless of whether it is an Operation or
// Module.
// TODO: As implemented, this facility is only sufficient for modeling the
// trivial module parent back-reference. Generalize this to also account for
// transitions from detached to attached and address TODOs in the
// ir_operation.py regarding testing corresponding lifetime guarantees.
pybind11::object parentKeepAlive;
nanobind::object parentKeepAlive;
bool attached = true;
bool valid = true;
@ -733,17 +732,17 @@ private:
/// python types.
class PyOpView : public PyOperationBase {
public:
PyOpView(const pybind11::object &operationObject);
PyOpView(const nanobind::object &operationObject);
PyOperation &getOperation() override { return operation; }
pybind11::object getOperationObject() { return operationObject; }
nanobind::object getOperationObject() { return operationObject; }
static pybind11::object buildGeneric(
const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
pybind11::list operandList, std::optional<pybind11::dict> attributes,
static nanobind::object buildGeneric(
const nanobind::object &cls, std::optional<nanobind::list> resultTypeList,
nanobind::list operandList, std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const pybind11::object &maybeIp);
const nanobind::object &maybeIp);
/// Construct an instance of a class deriving from OpView, bypassing its
/// `__init__` method. The derived class will typically define a constructor
@ -752,12 +751,12 @@ public:
///
/// The caller is responsible for verifying that `operation` is a valid
/// operation to construct `cls` with.
static pybind11::object constructDerived(const pybind11::object &cls,
const PyOperation &operation);
static nanobind::object constructDerived(const nanobind::object &cls,
const nanobind::object &operation);
private:
PyOperation &operation; // For efficient, cast-free access from C++
pybind11::object operationObject; // Holds the reference.
nanobind::object operationObject; // Holds the reference.
};
/// Wrapper around an MlirRegion.
@ -830,7 +829,7 @@ public:
void checkValid() { return parentOperation->checkValid(); }
/// Gets a capsule wrapping the void* within the MlirBlock.
pybind11::object getCapsule();
nanobind::object getCapsule();
private:
PyOperationRef parentOperation;
@ -858,10 +857,10 @@ public:
void insert(PyOperationBase &operationBase);
/// Enter and exit the context manager.
pybind11::object contextEnter();
void contextExit(const pybind11::object &excType,
const pybind11::object &excVal,
const pybind11::object &excTb);
static nanobind::object contextEnter(nanobind::object insertionPoint);
void contextExit(const nanobind::object &excType,
const nanobind::object &excVal,
const nanobind::object &excTb);
PyBlock &getBlock() { return block; }
std::optional<PyOperationRef> &getRefOperation() { return refOperation; }
@ -886,13 +885,13 @@ public:
MlirType get() const { return type; }
/// Gets a capsule wrapping the void* within the MlirType.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyType from the MlirType wrapped by a capsule.
/// Note that PyType instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirType
/// is taken by calling this function.
static PyType createFromCapsule(pybind11::object capsule);
static PyType createFromCapsule(nanobind::object capsule);
private:
MlirType type;
@ -912,10 +911,10 @@ public:
MlirTypeID get() { return typeID; }
/// Gets a capsule wrapping the void* within the MlirTypeID.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
static PyTypeID createFromCapsule(pybind11::object capsule);
static PyTypeID createFromCapsule(nanobind::object capsule);
private:
MlirTypeID typeID;
@ -932,7 +931,7 @@ public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirType);
using GetTypeIDFunctionTy = MlirTypeID (*)();
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
@ -945,34 +944,38 @@ public:
static MlirType castFrom(PyType &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
throw py::value_error((llvm::Twine("Cannot cast type to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
.str());
auto origRepr =
nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
throw nanobind::value_error((llvm::Twine("Cannot cast type to ") +
DerivedTy::pyClassName + " (from " +
origRepr + ")")
.str()
.c_str());
}
return orig;
}
static void bind(pybind11::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
pybind11::arg("cast_from_type"));
static void bind(nanobind::module_ &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName);
cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("cast_from_type"));
cls.def_static(
"isinstance",
[](PyType &otherType) -> bool {
return DerivedTy::isaFunction(otherType);
},
pybind11::arg("other"));
cls.def_property_readonly_static(
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
nanobind::arg("other"));
cls.def_prop_ro_static(
"static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
throw py::attribute_error(
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
throw nanobind::attribute_error(
(DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
.str()
.c_str());
});
cls.def_property_readonly("typeid", [](PyType &self) {
return py::cast(self).attr("typeid").cast<MlirTypeID>();
cls.def_prop_ro("typeid", [](PyType &self) {
return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
});
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
@ -986,8 +989,8 @@ public:
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
pybind11::cpp_function(
[](PyType pyType) -> DerivedTy { return pyType; }));
nanobind::cast<nanobind::callable>(nanobind::cpp_function(
[](PyType pyType) -> DerivedTy { return pyType; })));
}
DerivedTy::bindDerived(cls);
@ -1008,13 +1011,13 @@ public:
MlirAttribute get() const { return attr; }
/// Gets a capsule wrapping the void* within the MlirAttribute.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
/// Note that PyAttribute instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAttribute
/// is taken by calling this function.
static PyAttribute createFromCapsule(pybind11::object capsule);
static PyAttribute createFromCapsule(nanobind::object capsule);
private:
MlirAttribute attr;
@ -1054,7 +1057,7 @@ public:
// Derived classes must define statics for:
// IsAFunctionTy isaFunction
// const char *pyClassName
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
using IsAFunctionTy = bool (*)(MlirAttribute);
using GetTypeIDFunctionTy = MlirTypeID (*)();
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
@ -1067,37 +1070,45 @@ public:
static MlirAttribute castFrom(PyAttribute &orig) {
if (!DerivedTy::isaFunction(orig)) {
auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
throw py::value_error((llvm::Twine("Cannot cast attribute to ") +
DerivedTy::pyClassName + " (from " + origRepr +
")")
.str());
auto origRepr =
nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") +
DerivedTy::pyClassName + " (from " +
origRepr + ")")
.str()
.c_str());
}
return orig;
}
static void bind(pybind11::module &m) {
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
pybind11::module_local());
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
pybind11::arg("cast_from_attr"));
static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) {
ClassTy cls;
if (slots) {
cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots));
} else {
cls = ClassTy(m, DerivedTy::pyClassName);
}
cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
nanobind::arg("cast_from_attr"));
cls.def_static(
"isinstance",
[](PyAttribute &otherAttr) -> bool {
return DerivedTy::isaFunction(otherAttr);
},
pybind11::arg("other"));
cls.def_property_readonly(
nanobind::arg("other"));
cls.def_prop_ro(
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
cls.def_property_readonly_static(
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
cls.def_prop_ro_static(
"static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
if (DerivedTy::getTypeIdFunction)
return DerivedTy::getTypeIdFunction();
throw py::attribute_error(
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
throw nanobind::attribute_error(
(DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
.str()
.c_str());
});
cls.def_property_readonly("typeid", [](PyAttribute &self) {
return py::cast(self).attr("typeid").cast<MlirTypeID>();
cls.def_prop_ro("typeid", [](PyAttribute &self) {
return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
});
cls.def("__repr__", [](DerivedTy &self) {
PyPrintAccumulator printAccum;
@ -1112,9 +1123,10 @@ public:
if (DerivedTy::getTypeIdFunction) {
PyGlobals::get().registerTypeCaster(
DerivedTy::getTypeIdFunction(),
pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
}));
nanobind::cast<nanobind::callable>(
nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
return pyAttribute;
})));
}
DerivedTy::bindDerived(cls);
@ -1146,13 +1158,13 @@ public:
void checkValid() { return parentOperation->checkValid(); }
/// Gets a capsule wrapping the void* within the MlirValue.
pybind11::object getCapsule();
nanobind::object getCapsule();
pybind11::object maybeDownCast();
nanobind::object maybeDownCast();
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
static PyValue createFromCapsule(pybind11::object capsule);
static PyValue createFromCapsule(nanobind::object capsule);
private:
PyOperationRef parentOperation;
@ -1169,13 +1181,13 @@ public:
MlirAffineExpr get() const { return affineExpr; }
/// Gets a capsule wrapping the void* within the MlirAffineExpr.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
/// Note that PyAffineExpr instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
/// is taken by calling this function.
static PyAffineExpr createFromCapsule(pybind11::object capsule);
static PyAffineExpr createFromCapsule(nanobind::object capsule);
PyAffineExpr add(const PyAffineExpr &other) const;
PyAffineExpr mul(const PyAffineExpr &other) const;
@ -1196,13 +1208,13 @@ public:
MlirAffineMap get() const { return affineMap; }
/// Gets a capsule wrapping the void* within the MlirAffineMap.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
/// Note that PyAffineMap instances are uniqued, so the returned object
/// may be a pre-existing object. Ownership of the underlying MlirAffineMap
/// is taken by calling this function.
static PyAffineMap createFromCapsule(pybind11::object capsule);
static PyAffineMap createFromCapsule(nanobind::object capsule);
private:
MlirAffineMap affineMap;
@ -1217,12 +1229,12 @@ public:
MlirIntegerSet get() const { return integerSet; }
/// Gets a capsule wrapping the void* within the MlirIntegerSet.
pybind11::object getCapsule();
nanobind::object getCapsule();
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
/// Note that PyIntegerSet instances may be uniqued, so the returned object
/// may be a pre-existing object. Integer sets are owned by the context.
static PyIntegerSet createFromCapsule(pybind11::object capsule);
static PyIntegerSet createFromCapsule(nanobind::object capsule);
private:
MlirIntegerSet integerSet;
@ -1239,7 +1251,7 @@ public:
/// Returns the symbol (opview) with the given name, throws if there is no
/// such symbol in the table.
pybind11::object dunderGetItem(const std::string &name);
nanobind::object dunderGetItem(const std::string &name);
/// Removes the given operation from the symbol table and erases it.
void erase(PyOperationBase &symbol);
@ -1269,7 +1281,7 @@ public:
/// Walks all symbol tables under and including 'from'.
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
pybind11::object callback);
nanobind::object callback);
/// Casts the bindings class into the C API structure.
operator MlirSymbolTable() { return symbolTable; }
@ -1289,16 +1301,16 @@ struct MLIRError {
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
};
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
void populateIRInterfaces(pybind11::module &m);
void populateIRTypes(pybind11::module &m);
void populateIRAffine(nanobind::module_ &m);
void populateIRAttributes(nanobind::module_ &m);
void populateIRCore(nanobind::module_ &m);
void populateIRInterfaces(nanobind::module_ &m);
void populateIRTypes(nanobind::module_ &m);
} // namespace python
} // namespace mlir
namespace pybind11 {
namespace nanobind {
namespace detail {
template <>
@ -1309,6 +1321,6 @@ struct type_caster<mlir::python::DefaultingPyLocation>
: MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
} // namespace detail
} // namespace pybind11
} // namespace nanobind
#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H

View File

@ -6,19 +6,26 @@
//
//===----------------------------------------------------------------------===//
// clang-format off
#include "IRModule.h"
#include "PybindUtils.h"
#include "mlir/Bindings/Python/IRTypes.h"
// clang-format on
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/pair.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>
#include <optional>
#include "IRModule.h"
#include "NanobindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Support.h"
#include <optional>
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
@ -48,7 +55,7 @@ public:
MlirType t = mlirIntegerTypeGet(context->get(), width);
return PyIntegerType(context->getRef(), t);
},
py::arg("width"), py::arg("context") = py::none(),
nb::arg("width"), nb::arg("context").none() = nb::none(),
"Create a signless integer type");
c.def_static(
"get_signed",
@ -56,7 +63,7 @@ public:
MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
return PyIntegerType(context->getRef(), t);
},
py::arg("width"), py::arg("context") = py::none(),
nb::arg("width"), nb::arg("context").none() = nb::none(),
"Create a signed integer type");
c.def_static(
"get_unsigned",
@ -64,25 +71,25 @@ public:
MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
return PyIntegerType(context->getRef(), t);
},
py::arg("width"), py::arg("context") = py::none(),
nb::arg("width"), nb::arg("context").none() = nb::none(),
"Create an unsigned integer type");
c.def_property_readonly(
c.def_prop_ro(
"width",
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
"Returns the width of the integer type");
c.def_property_readonly(
c.def_prop_ro(
"is_signless",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsSignless(self);
},
"Returns whether this is a signless integer");
c.def_property_readonly(
c.def_prop_ro(
"is_signed",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsSigned(self);
},
"Returns whether this is a signed integer");
c.def_property_readonly(
c.def_prop_ro(
"is_unsigned",
[](PyIntegerType &self) -> bool {
return mlirIntegerTypeIsUnsigned(self);
@ -107,7 +114,7 @@ public:
MlirType t = mlirIndexTypeGet(context->get());
return PyIndexType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a index type.");
nb::arg("context").none() = nb::none(), "Create a index type.");
}
};
@ -118,7 +125,7 @@ public:
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_property_readonly(
c.def_prop_ro(
"width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
"Returns the width of the floating-point type");
}
@ -141,7 +148,7 @@ public:
MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
return PyFloat4E2M1FNType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float4_e2m1fn type.");
nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
}
};
@ -162,7 +169,7 @@ public:
MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
return PyFloat6E2M3FNType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
}
};
@ -183,7 +190,7 @@ public:
MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
return PyFloat6E3M2FNType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float6_e3m2fn type.");
nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
}
};
@ -204,7 +211,7 @@ public:
MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
return PyFloat8E4M3FNType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
}
};
@ -224,7 +231,7 @@ public:
MlirType t = mlirFloat8E5M2TypeGet(context->get());
return PyFloat8E5M2Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e5m2 type.");
nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
}
};
@ -244,7 +251,7 @@ public:
MlirType t = mlirFloat8E4M3TypeGet(context->get());
return PyFloat8E4M3Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3 type.");
nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
}
};
@ -265,7 +272,8 @@ public:
MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
return PyFloat8E4M3FNUZType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
nb::arg("context").none() = nb::none(),
"Create a float8_e4m3fnuz type.");
}
};
@ -286,7 +294,8 @@ public:
MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
return PyFloat8E4M3B11FNUZType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type.");
nb::arg("context").none() = nb::none(),
"Create a float8_e4m3b11fnuz type.");
}
};
@ -307,7 +316,8 @@ public:
MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
return PyFloat8E5M2FNUZType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
nb::arg("context").none() = nb::none(),
"Create a float8_e5m2fnuz type.");
}
};
@ -327,7 +337,7 @@ public:
MlirType t = mlirFloat8E3M4TypeGet(context->get());
return PyFloat8E3M4Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e3m4 type.");
nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
}
};
@ -348,7 +358,8 @@ public:
MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
return PyFloat8E8M0FNUType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a float8_e8m0fnu type.");
nb::arg("context").none() = nb::none(),
"Create a float8_e8m0fnu type.");
}
};
@ -368,7 +379,7 @@ public:
MlirType t = mlirBF16TypeGet(context->get());
return PyBF16Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a bf16 type.");
nb::arg("context").none() = nb::none(), "Create a bf16 type.");
}
};
@ -388,7 +399,7 @@ public:
MlirType t = mlirF16TypeGet(context->get());
return PyF16Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a f16 type.");
nb::arg("context").none() = nb::none(), "Create a f16 type.");
}
};
@ -408,7 +419,7 @@ public:
MlirType t = mlirTF32TypeGet(context->get());
return PyTF32Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a tf32 type.");
nb::arg("context").none() = nb::none(), "Create a tf32 type.");
}
};
@ -428,7 +439,7 @@ public:
MlirType t = mlirF32TypeGet(context->get());
return PyF32Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a f32 type.");
nb::arg("context").none() = nb::none(), "Create a f32 type.");
}
};
@ -448,7 +459,7 @@ public:
MlirType t = mlirF64TypeGet(context->get());
return PyF64Type(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a f64 type.");
nb::arg("context").none() = nb::none(), "Create a f64 type.");
}
};
@ -468,7 +479,7 @@ public:
MlirType t = mlirNoneTypeGet(context->get());
return PyNoneType(context->getRef(), t);
},
py::arg("context") = py::none(), "Create a none type.");
nb::arg("context").none() = nb::none(), "Create a none type.");
}
};
@ -490,14 +501,15 @@ public:
MlirType t = mlirComplexTypeGet(elementType);
return PyComplexType(elementType.getContext(), t);
}
throw py::value_error(
throw nb::value_error(
(Twine("invalid '") +
py::repr(py::cast(elementType)).cast<std::string>() +
nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
"' and expected floating point or integer type.")
.str());
.str()
.c_str());
},
"Create a complex type");
c.def_property_readonly(
c.def_prop_ro(
"element_type",
[](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
"Returns element type.");
@ -508,22 +520,22 @@ public:
// Shaped Type Interface - ShapedType
void mlir::PyShapedType::bindDerived(ClassTy &c) {
c.def_property_readonly(
c.def_prop_ro(
"element_type",
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
"Returns the element type of the shaped type.");
c.def_property_readonly(
c.def_prop_ro(
"has_rank",
[](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
"Returns whether the given shaped type is ranked.");
c.def_property_readonly(
c.def_prop_ro(
"rank",
[](PyShapedType &self) {
self.requireHasRank();
return mlirShapedTypeGetRank(self);
},
"Returns the rank of the given ranked shaped type.");
c.def_property_readonly(
c.def_prop_ro(
"has_static_shape",
[](PyShapedType &self) -> bool {
return mlirShapedTypeHasStaticShape(self);
@ -535,7 +547,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
self.requireHasRank();
return mlirShapedTypeIsDynamicDim(self, dim);
},
py::arg("dim"),
nb::arg("dim"),
"Returns whether the dim-th dimension of the given shaped type is "
"dynamic.");
c.def(
@ -544,12 +556,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
self.requireHasRank();
return mlirShapedTypeGetDimSize(self, dim);
},
py::arg("dim"),
nb::arg("dim"),
"Returns the dim-th dimension of the given ranked shaped type.");
c.def_static(
"is_dynamic_size",
[](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
py::arg("dim_size"),
nb::arg("dim_size"),
"Returns whether the given dimension size indicates a dynamic "
"dimension.");
c.def(
@ -558,10 +570,10 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
self.requireHasRank();
return mlirShapedTypeIsDynamicStrideOrOffset(val);
},
py::arg("dim_size"),
nb::arg("dim_size"),
"Returns whether the given value is used as a placeholder for dynamic "
"strides and offsets in shaped types.");
c.def_property_readonly(
c.def_prop_ro(
"shape",
[](PyShapedType &self) {
self.requireHasRank();
@ -587,7 +599,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
void mlir::PyShapedType::requireHasRank() {
if (!mlirShapedTypeHasRank(*this)) {
throw py::value_error(
throw nb::value_error(
"calling this method requires that the type has a rank.");
}
}
@ -607,15 +619,15 @@ public:
using PyConcreteType::PyConcreteType;
static void bindDerived(ClassTy &c) {
c.def_static("get", &PyVectorType::get, py::arg("shape"),
py::arg("element_type"), py::kw_only(),
py::arg("scalable") = py::none(),
py::arg("scalable_dims") = py::none(),
py::arg("loc") = py::none(), "Create a vector type")
.def_property_readonly(
c.def_static("get", &PyVectorType::get, nb::arg("shape"),
nb::arg("element_type"), nb::kw_only(),
nb::arg("scalable").none() = nb::none(),
nb::arg("scalable_dims").none() = nb::none(),
nb::arg("loc").none() = nb::none(), "Create a vector type")
.def_prop_ro(
"scalable",
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
.def_property_readonly("scalable_dims", [](MlirType self) {
.def_prop_ro("scalable_dims", [](MlirType self) {
std::vector<bool> scalableDims;
size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
scalableDims.reserve(rank);
@ -627,11 +639,11 @@ public:
private:
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
std::optional<py::list> scalable,
std::optional<nb::list> scalable,
std::optional<std::vector<int64_t>> scalableDims,
DefaultingPyLocation loc) {
if (scalable && scalableDims) {
throw py::value_error("'scalable' and 'scalable_dims' kwargs "
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
"are mutually exclusive.");
}
@ -639,10 +651,10 @@ private:
MlirType type;
if (scalable) {
if (scalable->size() != shape.size())
throw py::value_error("Expected len(scalable) == len(shape).");
throw nb::value_error("Expected len(scalable) == len(shape).");
SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
*scalable, [](const py::handle &h) { return h.cast<bool>(); }));
*scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
scalableDimFlags.data(),
elementType);
@ -650,7 +662,7 @@ private:
SmallVector<bool> scalableDimFlags(shape.size(), false);
for (int64_t dim : *scalableDims) {
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
throw py::value_error("Scalable dimension index out of bounds.");
throw nb::value_error("Scalable dimension index out of bounds.");
scalableDimFlags[dim] = true;
}
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
@ -689,17 +701,17 @@ public:
throw MLIRError("Invalid type", errors.take());
return PyRankedTensorType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"),
py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
"Create a ranked tensor type");
c.def_property_readonly(
"encoding",
[](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return std::nullopt;
return encoding;
});
nb::arg("shape"), nb::arg("element_type"),
nb::arg("encoding").none() = nb::none(),
nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
c.def_prop_ro("encoding",
[](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
MlirAttribute encoding =
mlirRankedTensorTypeGetEncoding(self.get());
if (mlirAttributeIsNull(encoding))
return std::nullopt;
return encoding;
});
}
};
@ -723,7 +735,7 @@ public:
throw MLIRError("Invalid type", errors.take());
return PyUnrankedTensorType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("loc") = py::none(),
nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
"Create a unranked tensor type");
}
};
@ -754,10 +766,11 @@ public:
throw MLIRError("Invalid type", errors.take());
return PyMemRefType(elementType.getContext(), t);
},
py::arg("shape"), py::arg("element_type"),
py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
py::arg("loc") = py::none(), "Create a memref type")
.def_property_readonly(
nb::arg("shape"), nb::arg("element_type"),
nb::arg("layout").none() = nb::none(),
nb::arg("memory_space").none() = nb::none(),
nb::arg("loc").none() = nb::none(), "Create a memref type")
.def_prop_ro(
"layout",
[](PyMemRefType &self) -> MlirAttribute {
return mlirMemRefTypeGetLayout(self);
@ -775,14 +788,14 @@ public:
return {strides, offset};
},
"The strides and offset of the MemRef type.")
.def_property_readonly(
.def_prop_ro(
"affine_map",
[](PyMemRefType &self) -> PyAffineMap {
MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
return PyAffineMap(self.getContext(), map);
},
"The layout of the MemRef type as an affine map.")
.def_property_readonly(
.def_prop_ro(
"memory_space",
[](PyMemRefType &self) -> std::optional<MlirAttribute> {
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
@ -820,9 +833,9 @@ public:
throw MLIRError("Invalid type", errors.take());
return PyUnrankedMemRefType(elementType.getContext(), t);
},
py::arg("element_type"), py::arg("memory_space"),
py::arg("loc") = py::none(), "Create a unranked memref type")
.def_property_readonly(
nb::arg("element_type"), nb::arg("memory_space").none(),
nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
.def_prop_ro(
"memory_space",
[](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
@ -851,15 +864,15 @@ public:
elements.data());
return PyTupleType(context->getRef(), t);
},
py::arg("elements"), py::arg("context") = py::none(),
nb::arg("elements"), nb::arg("context").none() = nb::none(),
"Create a tuple type");
c.def(
"get_type",
[](PyTupleType &self, intptr_t pos) {
return mlirTupleTypeGetType(self, pos);
},
py::arg("pos"), "Returns the pos-th type in the tuple type.");
c.def_property_readonly(
nb::arg("pos"), "Returns the pos-th type in the tuple type.");
c.def_prop_ro(
"num_types",
[](PyTupleType &self) -> intptr_t {
return mlirTupleTypeGetNumTypes(self);
@ -887,13 +900,14 @@ public:
results.size(), results.data());
return PyFunctionType(context->getRef(), t);
},
py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
nb::arg("inputs"), nb::arg("results"),
nb::arg("context").none() = nb::none(),
"Gets a FunctionType from a list of input and result types");
c.def_property_readonly(
c.def_prop_ro(
"inputs",
[](PyFunctionType &self) {
MlirType t = self;
py::list types;
nb::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
++i) {
types.append(mlirFunctionTypeGetInput(t, i));
@ -901,10 +915,10 @@ public:
return types;
},
"Returns the list of input types in the FunctionType.");
c.def_property_readonly(
c.def_prop_ro(
"results",
[](PyFunctionType &self) {
py::list types;
nb::list types;
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
++i) {
types.append(mlirFunctionTypeGetResult(self, i));
@ -938,21 +952,21 @@ public:
toMlirStringRef(typeData));
return PyOpaqueType(context->getRef(), type);
},
py::arg("dialect_namespace"), py::arg("buffer"),
py::arg("context") = py::none(),
nb::arg("dialect_namespace"), nb::arg("buffer"),
nb::arg("context").none() = nb::none(),
"Create an unregistered (opaque) dialect type.");
c.def_property_readonly(
c.def_prop_ro(
"dialect_namespace",
[](PyOpaqueType &self) {
MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
return py::str(stringRef.data, stringRef.length);
return nb::str(stringRef.data, stringRef.length);
},
"Returns the dialect namespace for the Opaque type as a string.");
c.def_property_readonly(
c.def_prop_ro(
"data",
[](PyOpaqueType &self) {
MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
return py::str(stringRef.data, stringRef.length);
return nb::str(stringRef.data, stringRef.length);
},
"Returns the data for the Opaque type as a string.");
}
@ -960,7 +974,7 @@ public:
} // namespace
void mlir::python::populateIRTypes(py::module &m) {
void mlir::python::populateIRTypes(nb::module_ &m) {
PyIntegerType::bind(m);
PyFloatType::bind(m);
PyIndexType::bind(m);

View File

@ -6,29 +6,31 @@
//
//===----------------------------------------------------------------------===//
#include "PybindUtils.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/string.h>
#include "Globals.h"
#include "IRModule.h"
#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlir;
using namespace py::literals;
using namespace nb::literals;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
PYBIND11_MODULE(_mlir, m) {
NB_MODULE(_mlir, m) {
m.doc() = "MLIR Python Native Extension";
py::class_<PyGlobals>(m, "_Globals", py::module_local())
.def_property("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
nb::class_<PyGlobals>(m, "_Globals")
.def_prop_rw("dialect_search_modules",
&PyGlobals::getDialectSearchPrefixes,
&PyGlobals::setDialectSearchPrefixes)
.def(
"append_dialect_search_prefix",
[](PyGlobals &self, std::string moduleName) {
@ -45,22 +47,21 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a, py::kw_only(),
"operation_name"_a, "operation_class"_a, nb::kw_only(),
"replace"_a = false,
"Testing hook for directly registering an operation");
// Aside from making the globals accessible to python, having python manage
// it is necessary to make sure it is destroyed (and releases its python
// resources) properly.
m.attr("globals") =
py::cast(new PyGlobals, py::return_value_policy::take_ownership);
m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
// Registration decorators.
m.def(
"register_dialect",
[](py::type pyClass) {
[](nb::type_object pyClass) {
std::string dialectNamespace =
pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
return pyClass;
},
@ -68,45 +69,46 @@ PYBIND11_MODULE(_mlir, m) {
"Class decorator for registering a custom Dialect wrapper");
m.def(
"register_operation",
[](const py::type &dialectClass, bool replace) -> py::cpp_function {
return py::cpp_function(
[dialectClass, replace](py::type opClass) -> py::type {
[](const nb::type_object &dialectClass, bool replace) -> nb::object {
return nb::cpp_function(
[dialectClass,
replace](nb::type_object opClass) -> nb::type_object {
std::string operationName =
opClass.attr("OPERATION_NAME").cast<std::string>();
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
PyGlobals::get().registerOperationImpl(operationName, opClass,
replace);
// Dict-stuff the new opClass by name onto the dialect class.
py::object opClassName = opClass.attr("__name__");
nb::object opClassName = opClass.attr("__name__");
dialectClass.attr(opClassName) = opClass;
return opClass;
});
},
"dialect_class"_a, py::kw_only(), "replace"_a = false,
"dialect_class"_a, nb::kw_only(), "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
return py::cpp_function([mlirTypeID,
replace](py::object typeCaster) -> py::object {
[](MlirTypeID mlirTypeID, bool replace) -> nb::object {
return nb::cpp_function([mlirTypeID, replace](
nb::callable typeCaster) -> nb::object {
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
return typeCaster;
});
},
"typeid"_a, py::kw_only(), "replace"_a = false,
"typeid"_a, nb::kw_only(), "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
return py::cpp_function(
[mlirTypeID, replace](py::object valueCaster) -> py::object {
[](MlirTypeID mlirTypeID, bool replace) -> nb::object {
return nb::cpp_function(
[mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
replace);
return valueCaster;
});
},
"typeid"_a, py::kw_only(), "replace"_a = false,
"typeid"_a, nb::kw_only(), "replace"_a = false,
"Register a value caster for casting MLIR values to custom user values.");
// Define and populate IR submodule.

View File

@ -1,4 +1,5 @@
//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
//-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -9,13 +10,21 @@
#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
#include <nanobind/nanobind.h>
#include "mlir-c/Support.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
template <>
struct std::iterator_traits<nanobind::detail::fast_iterator> {
using value_type = nanobind::handle;
using reference = const value_type;
using pointer = void;
using difference_type = std::ptrdiff_t;
using iterator_category = std::forward_iterator_tag;
};
namespace mlir {
namespace python {
@ -54,14 +63,14 @@ private:
} // namespace python
} // namespace mlir
namespace pybind11 {
namespace nanobind {
namespace detail {
template <typename DefaultingTy>
struct MlirDefaultingCaster {
PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription));
bool load(pybind11::handle src, bool) {
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
if (src.is_none()) {
// Note that we do want an exception to propagate from here as it will be
// the most informative.
@ -76,20 +85,20 @@ struct MlirDefaultingCaster {
// code to produce nice error messages (other than "Cannot cast...").
try {
value = DefaultingTy{
pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
return true;
} catch (std::exception &) {
return false;
}
}
static handle cast(DefaultingTy src, return_value_policy policy,
handle parent) {
return pybind11::cast(src, policy);
static handle from_cpp(DefaultingTy src, rv_policy policy,
cleanup_list *cleanup) noexcept {
return nanobind::cast(src, policy);
}
};
} // namespace detail
} // namespace pybind11
} // namespace nanobind
//------------------------------------------------------------------------------
// Conversion utilities.
@ -100,7 +109,7 @@ namespace mlir {
/// Accumulates into a python string from a method that accepts an
/// MlirStringCallback.
struct PyPrintAccumulator {
pybind11::list parts;
nanobind::list parts;
void *getUserData() { return this; }
@ -108,15 +117,15 @@ struct PyPrintAccumulator {
return [](MlirStringRef part, void *userData) {
PyPrintAccumulator *printAccum =
static_cast<PyPrintAccumulator *>(userData);
pybind11::str pyPart(part.data,
nanobind::str pyPart(part.data,
part.length); // Decodes as UTF-8 by default.
printAccum->parts.append(std::move(pyPart));
};
}
pybind11::str join() {
pybind11::str delim("", 0);
return delim.attr("join")(parts);
nanobind::str join() {
nanobind::str delim("", 0);
return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
}
};
@ -124,21 +133,21 @@ struct PyPrintAccumulator {
/// or binary.
class PyFileAccumulator {
public:
PyFileAccumulator(const pybind11::object &fileObject, bool binary)
PyFileAccumulator(const nanobind::object &fileObject, bool binary)
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
void *getUserData() { return this; }
MlirStringCallback getCallback() {
return [](MlirStringRef part, void *userData) {
pybind11::gil_scoped_acquire acquire;
nanobind::gil_scoped_acquire acquire;
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
if (accum->binary) {
// Note: Still has to copy and not avoidable with this API.
pybind11::bytes pyBytes(part.data, part.length);
nanobind::bytes pyBytes(part.data, part.length);
accum->pyWriteFunction(pyBytes);
} else {
pybind11::str pyStr(part.data,
nanobind::str pyStr(part.data,
part.length); // Decodes as UTF-8 by default.
accum->pyWriteFunction(pyStr);
}
@ -146,7 +155,7 @@ public:
}
private:
pybind11::object pyWriteFunction;
nanobind::object pyWriteFunction;
bool binary;
};
@ -163,17 +172,17 @@ struct PySinglePartStringAccumulator {
assert(!accum->invoked &&
"PySinglePartStringAccumulator called back multiple times");
accum->invoked = true;
accum->value = pybind11::str(part.data, part.length);
accum->value = nanobind::str(part.data, part.length);
};
}
pybind11::str takeValue() {
nanobind::str takeValue() {
assert(invoked && "PySinglePartStringAccumulator not called back");
return std::move(value);
}
private:
pybind11::str value;
nanobind::str value;
bool invoked = false;
};
@ -208,7 +217,7 @@ private:
template <typename Derived, typename ElementTy>
class Sliceable {
protected:
using ClassTy = pybind11::class_<Derived>;
using ClassTy = nanobind::class_<Derived>;
/// Transforms `index` into a legal value to access the underlying sequence.
/// Returns <0 on failure.
@ -237,7 +246,7 @@ protected:
/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
pybind11::object getItem(intptr_t index) {
nanobind::object getItem(intptr_t index) {
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
@ -250,20 +259,20 @@ protected:
->getRawElement(linearizeIndex(index))
.maybeDownCast();
else
return pybind11::cast(
return nanobind::cast(
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}
/// Returns a new instance of the pseudo-container restricted to the given
/// slice. Returns a nullptr object on failure.
pybind11::object getItemSlice(PyObject *slice) {
nanobind::object getItemSlice(PyObject *slice) {
ssize_t start, stop, extraStep, sliceLength;
if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
&sliceLength) != 0) {
PyErr_SetString(PyExc_IndexError, "index out of range");
return {};
}
return pybind11::cast(static_cast<Derived *>(this)->slice(
return nanobind::cast(static_cast<Derived *>(this)->slice(
startIndex + start * step, sliceLength, step * extraStep));
}
@ -279,7 +288,7 @@ public:
// Negative indices mean we count from the end.
index = wrapIndex(index);
if (index < 0) {
throw pybind11::index_error("index out of range");
throw nanobind::index_error("index out of range");
}
return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
@ -304,39 +313,38 @@ public:
}
/// Binds the indexing and length methods in the Python class.
static void bind(pybind11::module &m) {
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
pybind11::module_local())
static void bind(nanobind::module_ &m) {
auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
.def("__add__", &Sliceable::dunderAdd);
Derived::bindDerived(clazz);
// Manually implement the sequence protocol via the C API. We do this
// because it is approx 4x faster than via pybind11, largely because that
// because it is approx 4x faster than via nanobind, largely because that
// formulation requires a C++ exception to be thrown to detect end of
// sequence.
// Since we are in a C-context, any C++ exception that happens here
// will terminate the program. There is nothing in this implementation
// that should throw in a non-terminal way, so we forgo further
// exception marshalling.
// See: https://github.com/pybind/pybind11/issues/2842
// See: https://github.com/pybind/nanobind/issues/2842
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
"must be heap type");
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
auto self = pybind11::cast<Derived *>(rawSelf);
auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
return self->length;
};
// sq_item is called as part of the sequence protocol for iteration,
// list construction, etc.
heap_type->as_sequence.sq_item =
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
return self->getItem(index).release().ptr();
};
// mp_subscript is used for both slices and integer lookups.
heap_type->as_mapping.mp_subscript =
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
auto self = pybind11::cast<Derived *>(rawSelf);
auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
if (!PyErr_Occurred()) {
// Integer indexing.

View File

@ -8,12 +8,16 @@
#include "Pass.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"
namespace py = pybind11;
using namespace py::literals;
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
using namespace mlir::python;
@ -34,16 +38,15 @@ public:
MlirPassManager get() { return passManager; }
void release() { passManager.ptr = nullptr; }
pybind11::object getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonPassManagerToCapsule(get()));
nb::object getCapsule() {
return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
}
static pybind11::object createFromCapsule(pybind11::object capsule) {
static nb::object createFromCapsule(nb::object capsule) {
MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
if (mlirPassManagerIsNull(rawPm))
throw py::error_already_set();
return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
throw nb::python_error();
return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
}
private:
@ -53,22 +56,23 @@ private:
} // namespace
/// Create the `mlir.passmanager` here.
void mlir::python::populatePassManagerSubmodule(py::module &m) {
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
py::class_<PyPassManager>(m, "PassManager", py::module_local())
.def(py::init<>([](const std::string &anchorOp,
DefaultingPyMlirContext context) {
MlirPassManager passManager = mlirPassManagerCreateOnOperation(
context->get(),
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
return new PyPassManager(passManager);
}),
"anchor_op"_a = py::str("any"), "context"_a = py::none(),
"Create a new PassManager for the current (or provided) Context.")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyPassManager::getCapsule)
nb::class_<PyPassManager>(m, "PassManager")
.def(
"__init__",
[](PyPassManager &self, const std::string &anchorOp,
DefaultingPyMlirContext context) {
MlirPassManager passManager = mlirPassManagerCreateOnOperation(
context->get(),
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
new (&self) PyPassManager(passManager);
},
"anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
"Create a new PassManager for the current (or provided) Context.")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
.def("_testing_release", &PyPassManager::release,
"Releases (leaks) the backing pass manager (testing)")
@ -101,9 +105,9 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
"print_before_all"_a = false, "print_after_all"_a = true,
"print_module_scope"_a = false, "print_after_change"_a = false,
"print_after_failure"_a = false,
"large_elements_limit"_a = py::none(), "enable_debug_info"_a = false,
"print_generic_op_form"_a = false,
"tree_printing_dir_path"_a = py::none(),
"large_elements_limit"_a.none() = nb::none(),
"enable_debug_info"_a = false, "print_generic_op_form"_a = false,
"tree_printing_dir_path"_a.none() = nb::none(),
"Enable IR printing, default as mlir-print-ir-after-all.")
.def(
"enable_verifier",
@ -121,10 +125,10 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
mlirStringRefCreate(pipeline.data(), pipeline.size()),
errorMsg.getCallback(), errorMsg.getUserData());
if (mlirLogicalResultIsFailure(status))
throw py::value_error(std::string(errorMsg.join()));
throw nb::value_error(errorMsg.join().c_str());
return new PyPassManager(passManager);
},
"pipeline"_a, "context"_a = py::none(),
"pipeline"_a, "context"_a.none() = nb::none(),
"Parse a textual pass-pipeline and return a top-level PassManager "
"that can be applied on a Module. Throw a ValueError if the pipeline "
"can't be parsed")
@ -137,7 +141,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
mlirStringRefCreate(pipeline.data(), pipeline.size()),
errorMsg.getCallback(), errorMsg.getUserData());
if (mlirLogicalResultIsFailure(status))
throw py::value_error(std::string(errorMsg.join()));
throw nb::value_error(errorMsg.join().c_str());
},
"pipeline"_a,
"Add textual pipeline elements to the pass manager. Throws a "

View File

@ -9,12 +9,12 @@
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
#define MLIR_BINDINGS_PYTHON_PASS_H
#include "PybindUtils.h"
#include "NanobindUtils.h"
namespace mlir {
namespace python {
void populatePassManagerSubmodule(pybind11::module &m);
void populatePassManagerSubmodule(nanobind::module_ &m);
} // namespace python
} // namespace mlir

View File

@ -8,14 +8,16 @@
#include "Rewrite.h"
#include <nanobind/nanobind.h>
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Rewrite.h"
#include "mlir/Config/mlir-config.h"
namespace py = pybind11;
namespace nb = nanobind;
using namespace mlir;
using namespace py::literals;
using namespace nb::literals;
using namespace mlir::python;
namespace {
@ -54,18 +56,17 @@ public:
}
MlirFrozenRewritePatternSet get() { return set; }
pybind11::object getCapsule() {
return py::reinterpret_steal<py::object>(
nb::object getCapsule() {
return nb::steal<nb::object>(
mlirPythonFrozenRewritePatternSetToCapsule(get()));
}
static pybind11::object createFromCapsule(pybind11::object capsule) {
static nb::object createFromCapsule(nb::object capsule) {
MlirFrozenRewritePatternSet rawPm =
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
if (rawPm.ptr == nullptr)
throw py::error_already_set();
return py::cast(PyFrozenRewritePatternSet(rawPm),
py::return_value_policy::move);
throw nb::python_error();
return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
}
private:
@ -75,25 +76,27 @@ private:
} // namespace
/// Create the `mlir.rewrite` here.
void mlir::python::populateRewriteSubmodule(py::module &m) {
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
.def(py::init<>([](MlirModule module) {
return mlirPDLPatternModuleFromModule(module);
}),
"module"_a, "Create a PDL module from the given module.")
nb::class_<PyPDLPatternModule>(m, "PDLModule")
.def(
"__init__",
[](PyPDLPatternModule &self, MlirModule module) {
new (&self)
PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
},
"module"_a, "Create a PDL module from the given module.")
.def("freeze", [](PyPDLPatternModule &self) {
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
mlirRewritePatternSetFromPDLPatternModule(self.get())));
});
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
py::module_local())
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyFrozenRewritePatternSet::getCapsule)
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyFrozenRewritePatternSet::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
&PyFrozenRewritePatternSet::createFromCapsule);
m.def(
@ -102,7 +105,7 @@ void mlir::python::populateRewriteSubmodule(py::module &m) {
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
if (mlirLogicalResultIsFailure(status))
// FIXME: Not sure this is the right error to throw here.
throw py::value_error("pattern application failed to converge");
throw nb::value_error("pattern application failed to converge");
},
"module"_a, "set"_a,
"Applys the given patterns to the given module greedily while folding "

View File

@ -9,12 +9,12 @@
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
#define MLIR_BINDINGS_PYTHON_REWRITE_H
#include "PybindUtils.h"
#include "NanobindUtils.h"
namespace mlir {
namespace python {
void populateRewriteSubmodule(pybind11::module &m);
void populateRewriteSubmodule(nanobind::module_ &m);
} // namespace python
} // namespace mlir

View File

@ -448,6 +448,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
MODULE_NAME _mlir
ADD_TO_PARENT MLIRPythonSources.Core
ROOT_DIR "${PYTHON_SOURCE_DIR}"
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
MainModule.cpp
IRAffine.cpp
@ -463,7 +464,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
Globals.h
IRModule.h
Pass.h
PybindUtils.h
NanobindUtils.h
Rewrite.h
PRIVATE_LINK_LIBS
LLVMSupport

View File

@ -1,4 +1,4 @@
nanobind>=2.0, <3.0
nanobind>=2.4, <3.0
numpy>=1.19.5, <=2.1.2
pybind11>=2.10.0, <=2.13.6
PyYAML>=5.4.0, <=6.0.1

View File

@ -176,5 +176,6 @@ def testWalkSymbolTables():
try:
SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
except RuntimeError as e:
# CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
# CHECK: GOT EXCEPTION: Exception raised in callback:
# CHECK: AssertionError: Raised from python
print(f"GOT EXCEPTION: {e}")

View File

@ -168,9 +168,9 @@ maybe(
http_archive,
name = "nanobind",
build_file = "@llvm-raw//utils/bazel/third_party_build:nanobind.BUILD",
sha256 = "bfbfc7e5759f1669e4ddb48752b1ddc5647d1430e94614d6f8626df1d508e65a",
strip_prefix = "nanobind-2.2.0",
url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.2.0.tar.gz",
sha256 = "bb35deaed7efac5029ed1e33880a415638352f757d49207a8e6013fefb6c49a7",
strip_prefix = "nanobind-2.4.0",
url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.4.0.tar.gz",
)
load("@rules_python//python:repositories.bzl", "py_repositories", "python_register_toolchains")

View File

@ -1044,6 +1044,9 @@ cc_library(
srcs = [":MLIRBindingsPythonSourceFiles"],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
includes = [
"lib/Bindings/Python",
],
textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
deps = [
":CAPIAsync",
@ -1051,11 +1054,11 @@ cc_library(
":CAPIIR",
":CAPIInterfaces",
":CAPITransforms",
":MLIRBindingsPythonHeadersAndDeps",
":MLIRBindingsPythonNanobindHeadersAndDeps",
":Support",
":config",
"//llvm:Support",
"@pybind11",
"@nanobind",
"@rules_python//python/cc:current_py_cc_headers",
],
)
@ -1065,17 +1068,20 @@ cc_library(
srcs = [":MLIRBindingsPythonSourceFiles"],
copts = PYBIND11_COPTS,
features = PYBIND11_FEATURES,
includes = [
"lib/Bindings/Python",
],
textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
deps = [
":CAPIAsyncHeaders",
":CAPIDebugHeaders",
":CAPIIRHeaders",
":CAPITransformsHeaders",
":MLIRBindingsPythonHeaders",
":MLIRBindingsPythonNanobindHeaders",
":Support",
":config",
"//llvm:Support",
"@pybind11",
"@nanobind",
"@rules_python//python/cc:current_py_cc_headers",
],
)
@ -1108,6 +1114,7 @@ cc_binary(
deps = [
":MLIRBindingsPythonCore",
":MLIRBindingsPythonHeadersAndDeps",
"@nanobind",
],
)