[mlir python] Port Python core code to nanobind. (#118583)
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing `pybind11::` to `nanobind::`. Notes: * this PR needs Nanobind 2.4.0, because it needs a bug fix (https://github.com/wjakob/nanobind/pull/806) that landed in that release. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in `PybindAdapters.h`. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector<bool>, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (`nb::bytes`) from strings (e.g., `std::string`). This required nb::bytes overloads in a few places.
This commit is contained in:
parent
bfd05102d8
commit
41bd35b58b
@ -39,7 +39,7 @@ macro(mlir_configure_python_dev_packages)
|
||||
"extension = '${PYTHON_MODULE_EXTENSION}")
|
||||
|
||||
mlir_detect_nanobind_install()
|
||||
find_package(nanobind 2.2 CONFIG REQUIRED)
|
||||
find_package(nanobind 2.4 CONFIG REQUIRED)
|
||||
message(STATUS "Found nanobind v${nanobind_VERSION}: ${nanobind_INCLUDE_DIR}")
|
||||
message(STATUS "Python prefix = '${PYTHON_MODULE_PREFIX}', "
|
||||
"suffix = '${PYTHON_MODULE_SUFFIX}', "
|
||||
|
@ -9,7 +9,7 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
|
||||
#define MLIR_BINDINGS_PYTHON_IRTYPES_H
|
||||
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
@ -374,9 +374,8 @@ public:
|
||||
static_assert(!std::is_member_function_pointer<Func>::value,
|
||||
"def_staticmethod(...) called with a non-static member "
|
||||
"function pointer");
|
||||
py::cpp_function cf(
|
||||
std::forward<Func>(f), py::name(name), py::scope(thisClass),
|
||||
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
|
||||
py::cpp_function cf(std::forward<Func>(f), py::name(name),
|
||||
py::scope(thisClass), extra...);
|
||||
thisClass.attr(cf.name()) = py::staticmethod(cf);
|
||||
return *this;
|
||||
}
|
||||
@ -387,9 +386,8 @@ public:
|
||||
static_assert(!std::is_member_function_pointer<Func>::value,
|
||||
"def_classmethod(...) called with a non-static member "
|
||||
"function pointer");
|
||||
py::cpp_function cf(
|
||||
std::forward<Func>(f), py::name(name), py::scope(thisClass),
|
||||
py::sibling(py::getattr(thisClass, name, py::none())), extra...);
|
||||
py::cpp_function cf(std::forward<Func>(f), py::name(name),
|
||||
py::scope(thisClass), extra...);
|
||||
thisClass.attr(cf.name()) =
|
||||
py::reinterpret_borrow<py::object>(PyClassMethod_New(cf.ptr()));
|
||||
return *this;
|
||||
|
@ -9,18 +9,17 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H
|
||||
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
|
||||
|
||||
#include "PybindUtils.h"
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "NanobindUtils.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
@ -57,55 +56,55 @@ public:
|
||||
/// Raises an exception if the mapping already exists and replace == false.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerAttributeBuilder(const std::string &attributeKind,
|
||||
pybind11::function pyFunc,
|
||||
nanobind::callable pyFunc,
|
||||
bool replace = false);
|
||||
|
||||
/// Adds a user-friendly type caster. Raises an exception if the mapping
|
||||
/// already exists and replace == false. This is intended to be called by
|
||||
/// implementation code.
|
||||
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
|
||||
void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster,
|
||||
bool replace = false);
|
||||
|
||||
/// Adds a user-friendly value caster. Raises an exception if the mapping
|
||||
/// already exists and replace == false. This is intended to be called by
|
||||
/// implementation code.
|
||||
void registerValueCaster(MlirTypeID mlirTypeID,
|
||||
pybind11::function valueCaster,
|
||||
nanobind::callable valueCaster,
|
||||
bool replace = false);
|
||||
|
||||
/// Adds a concrete implementation dialect class.
|
||||
/// Raises an exception if the mapping already exists.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerDialectImpl(const std::string &dialectNamespace,
|
||||
pybind11::object pyClass);
|
||||
nanobind::object pyClass);
|
||||
|
||||
/// Adds a concrete implementation operation class.
|
||||
/// Raises an exception if the mapping already exists and replace == false.
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerOperationImpl(const std::string &operationName,
|
||||
pybind11::object pyClass, bool replace = false);
|
||||
nanobind::object pyClass, bool replace = false);
|
||||
|
||||
/// Returns the custom Attribute builder for Attribute kind.
|
||||
std::optional<pybind11::function>
|
||||
std::optional<nanobind::callable>
|
||||
lookupAttributeBuilder(const std::string &attributeKind);
|
||||
|
||||
/// Returns the custom type caster for MlirTypeID mlirTypeID.
|
||||
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
std::optional<nanobind::callable> lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect);
|
||||
|
||||
/// Returns the custom value caster for MlirTypeID mlirTypeID.
|
||||
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
|
||||
std::optional<nanobind::callable> lookupValueCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect);
|
||||
|
||||
/// Looks up a registered dialect class by namespace. Note that this may
|
||||
/// trigger loading of the defining module and can arbitrarily re-enter.
|
||||
std::optional<pybind11::object>
|
||||
std::optional<nanobind::object>
|
||||
lookupDialectClass(const std::string &dialectNamespace);
|
||||
|
||||
/// Looks up a registered operation class (deriving from OpView) by operation
|
||||
/// name. Note that this may trigger a load of the dialect, which can
|
||||
/// arbitrarily re-enter.
|
||||
std::optional<pybind11::object>
|
||||
std::optional<nanobind::object>
|
||||
lookupOperationClass(llvm::StringRef operationName);
|
||||
|
||||
private:
|
||||
@ -113,15 +112,15 @@ private:
|
||||
/// Module name prefixes to search under for dialect implementation modules.
|
||||
std::vector<std::string> dialectSearchPrefixes;
|
||||
/// Map of dialect namespace to external dialect class object.
|
||||
llvm::StringMap<pybind11::object> dialectClassMap;
|
||||
llvm::StringMap<nanobind::object> dialectClassMap;
|
||||
/// Map of full operation name to external operation class object.
|
||||
llvm::StringMap<pybind11::object> operationClassMap;
|
||||
llvm::StringMap<nanobind::object> operationClassMap;
|
||||
/// Map of attribute ODS name to custom builder.
|
||||
llvm::StringMap<pybind11::object> attributeBuilderMap;
|
||||
llvm::StringMap<nanobind::callable> attributeBuilderMap;
|
||||
/// Map of MlirTypeID to custom type caster.
|
||||
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;
|
||||
llvm::DenseMap<MlirTypeID, nanobind::callable> typeCasterMap;
|
||||
/// Map of MlirTypeID to custom value caster.
|
||||
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
|
||||
llvm::DenseMap<MlirTypeID, nanobind::callable> valueCasterMap;
|
||||
/// Set of dialect namespaces that we have attempted to import implementation
|
||||
/// modules for.
|
||||
llvm::StringSet<> loadedDialectModules;
|
||||
|
@ -6,20 +6,19 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/detail/common.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "IRModule.h"
|
||||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "NanobindUtils.h"
|
||||
#include "mlir-c/AffineExpr.h"
|
||||
#include "mlir-c/AffineMap.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
@ -30,7 +29,7 @@
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
@ -46,23 +45,23 @@ static const char kDumpDocstring[] =
|
||||
/// Throws errors in case of failure, using "action" to describe what the caller
|
||||
/// was attempting to do.
|
||||
template <typename PyType, typename CType>
|
||||
static void pyListToVector(const py::list &list,
|
||||
static void pyListToVector(const nb::list &list,
|
||||
llvm::SmallVectorImpl<CType> &result,
|
||||
StringRef action) {
|
||||
result.reserve(py::len(list));
|
||||
for (py::handle item : list) {
|
||||
result.reserve(nb::len(list));
|
||||
for (nb::handle item : list) {
|
||||
try {
|
||||
result.push_back(item.cast<PyType>());
|
||||
} catch (py::cast_error &err) {
|
||||
result.push_back(nb::cast<PyType>(item));
|
||||
} catch (nb::cast_error &err) {
|
||||
std::string msg = (llvm::Twine("Invalid expression when ") + action +
|
||||
" (" + err.what() + ")")
|
||||
.str();
|
||||
throw py::cast_error(msg);
|
||||
} catch (py::reference_cast_error &err) {
|
||||
throw std::runtime_error(msg.c_str());
|
||||
} catch (std::runtime_error &err) {
|
||||
std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
|
||||
action + " (" + err.what() + ")")
|
||||
.str();
|
||||
throw py::cast_error(msg);
|
||||
throw std::runtime_error(msg.c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -94,7 +93,7 @@ public:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
// and redefine bindDerived.
|
||||
using ClassTy = py::class_<DerivedTy, BaseTy>;
|
||||
using ClassTy = nb::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = bool (*)(MlirAffineExpr);
|
||||
|
||||
PyConcreteAffineExpr() = default;
|
||||
@ -105,24 +104,25 @@ public:
|
||||
|
||||
static MlirAffineExpr castFrom(PyAffineExpr &orig) {
|
||||
if (!DerivedTy::isaFunction(orig)) {
|
||||
auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
|
||||
throw py::value_error((Twine("Cannot cast affine expression to ") +
|
||||
auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
|
||||
throw nb::value_error((Twine("Cannot cast affine expression to ") +
|
||||
DerivedTy::pyClassName + " (from " + origRepr +
|
||||
")")
|
||||
.str());
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
return orig;
|
||||
}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
|
||||
cls.def(py::init<PyAffineExpr &>(), py::arg("expr"));
|
||||
static void bind(nb::module_ &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(nb::init<PyAffineExpr &>(), nb::arg("expr"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyAffineExpr &otherAffineExpr) -> bool {
|
||||
return DerivedTy::isaFunction(otherAffineExpr);
|
||||
},
|
||||
py::arg("other"));
|
||||
nb::arg("other"));
|
||||
DerivedTy::bindDerived(cls);
|
||||
}
|
||||
|
||||
@ -144,9 +144,9 @@ public:
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"),
|
||||
py::arg("context") = py::none());
|
||||
c.def_property_readonly("value", [](PyAffineConstantExpr &self) {
|
||||
c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"),
|
||||
nb::arg("context").none() = nb::none());
|
||||
c.def_prop_ro("value", [](PyAffineConstantExpr &self) {
|
||||
return mlirAffineConstantExprGetValue(self);
|
||||
});
|
||||
}
|
||||
@ -164,9 +164,9 @@ public:
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static("get", &PyAffineDimExpr::get, py::arg("position"),
|
||||
py::arg("context") = py::none());
|
||||
c.def_property_readonly("position", [](PyAffineDimExpr &self) {
|
||||
c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"),
|
||||
nb::arg("context").none() = nb::none());
|
||||
c.def_prop_ro("position", [](PyAffineDimExpr &self) {
|
||||
return mlirAffineDimExprGetPosition(self);
|
||||
});
|
||||
}
|
||||
@ -184,9 +184,9 @@ public:
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"),
|
||||
py::arg("context") = py::none());
|
||||
c.def_property_readonly("position", [](PyAffineSymbolExpr &self) {
|
||||
c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"),
|
||||
nb::arg("context").none() = nb::none());
|
||||
c.def_prop_ro("position", [](PyAffineSymbolExpr &self) {
|
||||
return mlirAffineSymbolExprGetPosition(self);
|
||||
});
|
||||
}
|
||||
@ -209,8 +209,8 @@ public:
|
||||
}
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs);
|
||||
c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs);
|
||||
c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs);
|
||||
c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs);
|
||||
}
|
||||
};
|
||||
|
||||
@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const {
|
||||
return mlirAffineExprEqual(affineExpr, other.affineExpr);
|
||||
}
|
||||
|
||||
py::object PyAffineExpr::getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(
|
||||
mlirPythonAffineExprToCapsule(*this));
|
||||
nb::object PyAffineExpr::getCapsule() {
|
||||
return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this));
|
||||
}
|
||||
|
||||
PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) {
|
||||
PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) {
|
||||
MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr());
|
||||
if (mlirAffineExprIsNull(rawAffineExpr))
|
||||
throw py::error_already_set();
|
||||
throw nb::python_error();
|
||||
return PyAffineExpr(
|
||||
PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)),
|
||||
rawAffineExpr);
|
||||
@ -424,14 +423,14 @@ bool PyAffineMap::operator==(const PyAffineMap &other) const {
|
||||
return mlirAffineMapEqual(affineMap, other.affineMap);
|
||||
}
|
||||
|
||||
py::object PyAffineMap::getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this));
|
||||
nb::object PyAffineMap::getCapsule() {
|
||||
return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this));
|
||||
}
|
||||
|
||||
PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
|
||||
PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) {
|
||||
MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr());
|
||||
if (mlirAffineMapIsNull(rawAffineMap))
|
||||
throw py::error_already_set();
|
||||
throw nb::python_error();
|
||||
return PyAffineMap(
|
||||
PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)),
|
||||
rawAffineMap);
|
||||
@ -454,11 +453,10 @@ public:
|
||||
|
||||
bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint",
|
||||
py::module_local())
|
||||
.def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
|
||||
.def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
|
||||
static void bind(nb::module_ &m) {
|
||||
nb::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
|
||||
.def_prop_ro("expr", &PyIntegerSetConstraint::getExpr)
|
||||
.def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -501,27 +499,25 @@ bool PyIntegerSet::operator==(const PyIntegerSet &other) const {
|
||||
return mlirIntegerSetEqual(integerSet, other.integerSet);
|
||||
}
|
||||
|
||||
py::object PyIntegerSet::getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(
|
||||
mlirPythonIntegerSetToCapsule(*this));
|
||||
nb::object PyIntegerSet::getCapsule() {
|
||||
return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this));
|
||||
}
|
||||
|
||||
PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
|
||||
PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) {
|
||||
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
|
||||
if (mlirIntegerSetIsNull(rawIntegerSet))
|
||||
throw py::error_already_set();
|
||||
throw nb::python_error();
|
||||
return PyIntegerSet(
|
||||
PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
|
||||
rawIntegerSet);
|
||||
}
|
||||
|
||||
void mlir::python::populateIRAffine(py::module &m) {
|
||||
void mlir::python::populateIRAffine(nb::module_ &m) {
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyAffineExpr and derived classes.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyAffineExpr>(m, "AffineExpr", py::module_local())
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyAffineExpr::getCapsule)
|
||||
nb::class_<PyAffineExpr>(m, "AffineExpr")
|
||||
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule)
|
||||
.def("__add__", &PyAffineAddExpr::get)
|
||||
.def("__add__", &PyAffineAddExpr::getRHSConstant)
|
||||
@ -558,7 +554,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
.def("__eq__", [](PyAffineExpr &self,
|
||||
PyAffineExpr &other) { return self == other; })
|
||||
.def("__eq__",
|
||||
[](PyAffineExpr &self, py::object &other) { return false; })
|
||||
[](PyAffineExpr &self, nb::object &other) { return false; })
|
||||
.def("__str__",
|
||||
[](PyAffineExpr &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
@ -579,7 +575,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
[](PyAffineExpr &self) {
|
||||
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
|
||||
})
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"context",
|
||||
[](PyAffineExpr &self) { return self.getContext().getObject(); })
|
||||
.def("compose",
|
||||
@ -632,16 +628,16 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
.def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant,
|
||||
"Gets an affine expression containing the rounded-up result "
|
||||
"of dividing an expression by a constant.")
|
||||
.def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"),
|
||||
py::arg("context") = py::none(),
|
||||
.def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Gets a constant affine expression with the given value.")
|
||||
.def_static(
|
||||
"get_dim", &PyAffineDimExpr::get, py::arg("position"),
|
||||
py::arg("context") = py::none(),
|
||||
"get_dim", &PyAffineDimExpr::get, nb::arg("position"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Gets an affine expression of a dimension at the given position.")
|
||||
.def_static(
|
||||
"get_symbol", &PyAffineSymbolExpr::get, py::arg("position"),
|
||||
py::arg("context") = py::none(),
|
||||
"get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Gets an affine expression of a symbol at the given position.")
|
||||
.def(
|
||||
"dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); },
|
||||
@ -659,13 +655,12 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyAffineMap.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyAffineMap>(m, "AffineMap", py::module_local())
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyAffineMap::getCapsule)
|
||||
nb::class_<PyAffineMap>(m, "AffineMap")
|
||||
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule)
|
||||
.def("__eq__",
|
||||
[](PyAffineMap &self, PyAffineMap &other) { return self == other; })
|
||||
.def("__eq__", [](PyAffineMap &self, py::object &other) { return false; })
|
||||
.def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; })
|
||||
.def("__str__",
|
||||
[](PyAffineMap &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
@ -687,7 +682,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
|
||||
})
|
||||
.def_static("compress_unused_symbols",
|
||||
[](py::list affineMaps, DefaultingPyMlirContext context) {
|
||||
[](nb::list affineMaps, DefaultingPyMlirContext context) {
|
||||
SmallVector<MlirAffineMap> maps;
|
||||
pyListToVector<PyAffineMap, MlirAffineMap>(
|
||||
affineMaps, maps, "attempting to create an AffineMap");
|
||||
@ -704,7 +699,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
res.emplace_back(context->getRef(), m);
|
||||
return res;
|
||||
})
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"context",
|
||||
[](PyAffineMap &self) { return self.getContext().getObject(); },
|
||||
"Context that owns the Affine Map")
|
||||
@ -713,7 +708,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
kDumpDocstring)
|
||||
.def_static(
|
||||
"get",
|
||||
[](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
|
||||
[](intptr_t dimCount, intptr_t symbolCount, nb::list exprs,
|
||||
DefaultingPyMlirContext context) {
|
||||
SmallVector<MlirAffineExpr> affineExprs;
|
||||
pyListToVector<PyAffineExpr, MlirAffineExpr>(
|
||||
@ -723,8 +718,8 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
affineExprs.size(), affineExprs.data());
|
||||
return PyAffineMap(context->getRef(), map);
|
||||
},
|
||||
py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"),
|
||||
py::arg("context") = py::none(),
|
||||
nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Gets a map with the given expressions as results.")
|
||||
.def_static(
|
||||
"get_constant",
|
||||
@ -733,7 +728,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
mlirAffineMapConstantGet(context->get(), value);
|
||||
return PyAffineMap(context->getRef(), affineMap);
|
||||
},
|
||||
py::arg("value"), py::arg("context") = py::none(),
|
||||
nb::arg("value"), nb::arg("context").none() = nb::none(),
|
||||
"Gets an affine map with a single constant result")
|
||||
.def_static(
|
||||
"get_empty",
|
||||
@ -741,7 +736,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get());
|
||||
return PyAffineMap(context->getRef(), affineMap);
|
||||
},
|
||||
py::arg("context") = py::none(), "Gets an empty affine map.")
|
||||
nb::arg("context").none() = nb::none(), "Gets an empty affine map.")
|
||||
.def_static(
|
||||
"get_identity",
|
||||
[](intptr_t nDims, DefaultingPyMlirContext context) {
|
||||
@ -749,7 +744,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
mlirAffineMapMultiDimIdentityGet(context->get(), nDims);
|
||||
return PyAffineMap(context->getRef(), affineMap);
|
||||
},
|
||||
py::arg("n_dims"), py::arg("context") = py::none(),
|
||||
nb::arg("n_dims"), nb::arg("context").none() = nb::none(),
|
||||
"Gets an identity map with the given number of dimensions.")
|
||||
.def_static(
|
||||
"get_minor_identity",
|
||||
@ -759,8 +754,8 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults);
|
||||
return PyAffineMap(context->getRef(), affineMap);
|
||||
},
|
||||
py::arg("n_dims"), py::arg("n_results"),
|
||||
py::arg("context") = py::none(),
|
||||
nb::arg("n_dims"), nb::arg("n_results"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Gets a minor identity map with the given number of dimensions and "
|
||||
"results.")
|
||||
.def_static(
|
||||
@ -768,13 +763,13 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
[](std::vector<unsigned> permutation,
|
||||
DefaultingPyMlirContext context) {
|
||||
if (!isPermutation(permutation))
|
||||
throw py::cast_error("Invalid permutation when attempting to "
|
||||
"create an AffineMap");
|
||||
throw std::runtime_error("Invalid permutation when attempting to "
|
||||
"create an AffineMap");
|
||||
MlirAffineMap affineMap = mlirAffineMapPermutationGet(
|
||||
context->get(), permutation.size(), permutation.data());
|
||||
return PyAffineMap(context->getRef(), affineMap);
|
||||
},
|
||||
py::arg("permutation"), py::arg("context") = py::none(),
|
||||
nb::arg("permutation"), nb::arg("context").none() = nb::none(),
|
||||
"Gets an affine map that permutes its inputs.")
|
||||
.def(
|
||||
"get_submap",
|
||||
@ -782,33 +777,33 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
intptr_t numResults = mlirAffineMapGetNumResults(self);
|
||||
for (intptr_t pos : resultPos) {
|
||||
if (pos < 0 || pos >= numResults)
|
||||
throw py::value_error("result position out of bounds");
|
||||
throw nb::value_error("result position out of bounds");
|
||||
}
|
||||
MlirAffineMap affineMap = mlirAffineMapGetSubMap(
|
||||
self, resultPos.size(), resultPos.data());
|
||||
return PyAffineMap(self.getContext(), affineMap);
|
||||
},
|
||||
py::arg("result_positions"))
|
||||
nb::arg("result_positions"))
|
||||
.def(
|
||||
"get_major_submap",
|
||||
[](PyAffineMap &self, intptr_t nResults) {
|
||||
if (nResults >= mlirAffineMapGetNumResults(self))
|
||||
throw py::value_error("number of results out of bounds");
|
||||
throw nb::value_error("number of results out of bounds");
|
||||
MlirAffineMap affineMap =
|
||||
mlirAffineMapGetMajorSubMap(self, nResults);
|
||||
return PyAffineMap(self.getContext(), affineMap);
|
||||
},
|
||||
py::arg("n_results"))
|
||||
nb::arg("n_results"))
|
||||
.def(
|
||||
"get_minor_submap",
|
||||
[](PyAffineMap &self, intptr_t nResults) {
|
||||
if (nResults >= mlirAffineMapGetNumResults(self))
|
||||
throw py::value_error("number of results out of bounds");
|
||||
throw nb::value_error("number of results out of bounds");
|
||||
MlirAffineMap affineMap =
|
||||
mlirAffineMapGetMinorSubMap(self, nResults);
|
||||
return PyAffineMap(self.getContext(), affineMap);
|
||||
},
|
||||
py::arg("n_results"))
|
||||
nb::arg("n_results"))
|
||||
.def(
|
||||
"replace",
|
||||
[](PyAffineMap &self, PyAffineExpr &expression,
|
||||
@ -818,39 +813,37 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
self, expression, replacement, numResultDims, numResultSyms);
|
||||
return PyAffineMap(self.getContext(), affineMap);
|
||||
},
|
||||
py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"),
|
||||
py::arg("n_result_syms"))
|
||||
.def_property_readonly(
|
||||
nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"),
|
||||
nb::arg("n_result_syms"))
|
||||
.def_prop_ro(
|
||||
"is_permutation",
|
||||
[](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); })
|
||||
.def_property_readonly("is_projected_permutation",
|
||||
[](PyAffineMap &self) {
|
||||
return mlirAffineMapIsProjectedPermutation(self);
|
||||
})
|
||||
.def_property_readonly(
|
||||
.def_prop_ro("is_projected_permutation",
|
||||
[](PyAffineMap &self) {
|
||||
return mlirAffineMapIsProjectedPermutation(self);
|
||||
})
|
||||
.def_prop_ro(
|
||||
"n_dims",
|
||||
[](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); })
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"n_inputs",
|
||||
[](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); })
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"n_symbols",
|
||||
[](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); })
|
||||
.def_property_readonly("results", [](PyAffineMap &self) {
|
||||
return PyAffineMapExprList(self);
|
||||
});
|
||||
.def_prop_ro("results",
|
||||
[](PyAffineMap &self) { return PyAffineMapExprList(self); });
|
||||
PyAffineMapExprList::bind(m);
|
||||
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of PyIntegerSet.
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyIntegerSet>(m, "IntegerSet", py::module_local())
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyIntegerSet::getCapsule)
|
||||
nb::class_<PyIntegerSet>(m, "IntegerSet")
|
||||
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
|
||||
.def("__eq__", [](PyIntegerSet &self,
|
||||
PyIntegerSet &other) { return self == other; })
|
||||
.def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
|
||||
.def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; })
|
||||
.def("__str__",
|
||||
[](PyIntegerSet &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
@ -871,7 +864,7 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
[](PyIntegerSet &self) {
|
||||
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
|
||||
})
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"context",
|
||||
[](PyIntegerSet &self) { return self.getContext().getObject(); })
|
||||
.def(
|
||||
@ -879,14 +872,14 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
kDumpDocstring)
|
||||
.def_static(
|
||||
"get",
|
||||
[](intptr_t numDims, intptr_t numSymbols, py::list exprs,
|
||||
[](intptr_t numDims, intptr_t numSymbols, nb::list exprs,
|
||||
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
|
||||
if (exprs.size() != eqFlags.size())
|
||||
throw py::value_error(
|
||||
throw nb::value_error(
|
||||
"Expected the number of constraints to match "
|
||||
"that of equality flags");
|
||||
if (exprs.empty())
|
||||
throw py::value_error("Expected non-empty list of constraints");
|
||||
if (exprs.size() == 0)
|
||||
throw nb::value_error("Expected non-empty list of constraints");
|
||||
|
||||
// Copy over to a SmallVector because std::vector has a
|
||||
// specialization for booleans that packs data and does not
|
||||
@ -901,8 +894,8 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
affineExprs.data(), flags.data());
|
||||
return PyIntegerSet(context->getRef(), set);
|
||||
},
|
||||
py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
|
||||
py::arg("eq_flags"), py::arg("context") = py::none())
|
||||
nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"),
|
||||
nb::arg("eq_flags"), nb::arg("context").none() = nb::none())
|
||||
.def_static(
|
||||
"get_empty",
|
||||
[](intptr_t numDims, intptr_t numSymbols,
|
||||
@ -911,20 +904,20 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
|
||||
return PyIntegerSet(context->getRef(), set);
|
||||
},
|
||||
py::arg("num_dims"), py::arg("num_symbols"),
|
||||
py::arg("context") = py::none())
|
||||
nb::arg("num_dims"), nb::arg("num_symbols"),
|
||||
nb::arg("context").none() = nb::none())
|
||||
.def(
|
||||
"get_replaced",
|
||||
[](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
|
||||
[](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs,
|
||||
intptr_t numResultDims, intptr_t numResultSymbols) {
|
||||
if (static_cast<intptr_t>(dimExprs.size()) !=
|
||||
mlirIntegerSetGetNumDims(self))
|
||||
throw py::value_error(
|
||||
throw nb::value_error(
|
||||
"Expected the number of dimension replacement expressions "
|
||||
"to match that of dimensions");
|
||||
if (static_cast<intptr_t>(symbolExprs.size()) !=
|
||||
mlirIntegerSetGetNumSymbols(self))
|
||||
throw py::value_error(
|
||||
throw nb::value_error(
|
||||
"Expected the number of symbol replacement expressions "
|
||||
"to match that of symbols");
|
||||
|
||||
@ -940,30 +933,30 @@ void mlir::python::populateIRAffine(py::module &m) {
|
||||
numResultDims, numResultSymbols);
|
||||
return PyIntegerSet(self.getContext(), set);
|
||||
},
|
||||
py::arg("dim_exprs"), py::arg("symbol_exprs"),
|
||||
py::arg("num_result_dims"), py::arg("num_result_symbols"))
|
||||
.def_property_readonly("is_canonical_empty",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetIsCanonicalEmpty(self);
|
||||
})
|
||||
.def_property_readonly(
|
||||
nb::arg("dim_exprs"), nb::arg("symbol_exprs"),
|
||||
nb::arg("num_result_dims"), nb::arg("num_result_symbols"))
|
||||
.def_prop_ro("is_canonical_empty",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetIsCanonicalEmpty(self);
|
||||
})
|
||||
.def_prop_ro(
|
||||
"n_dims",
|
||||
[](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"n_symbols",
|
||||
[](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"n_inputs",
|
||||
[](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
|
||||
.def_property_readonly("n_equalities",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetGetNumEqualities(self);
|
||||
})
|
||||
.def_property_readonly("n_inequalities",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetGetNumInequalities(self);
|
||||
})
|
||||
.def_property_readonly("constraints", [](PyIntegerSet &self) {
|
||||
.def_prop_ro("n_equalities",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetGetNumEqualities(self);
|
||||
})
|
||||
.def_prop_ro("n_inequalities",
|
||||
[](PyIntegerSet &self) {
|
||||
return mlirIntegerSetGetNumInequalities(self);
|
||||
})
|
||||
.def_prop_ro("constraints", [](PyIntegerSet &self) {
|
||||
return PyIntegerSetConstraintList(self);
|
||||
});
|
||||
PyIntegerSetConstraint::bind(m);
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -6,12 +6,12 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <pybind11/cast.h>
|
||||
#include <pybind11/detail/common.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/pytypes.h>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -24,7 +24,7 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
@ -53,10 +53,10 @@ namespace {
|
||||
|
||||
/// Takes in an optional ist of operands and converts them into a SmallVector
|
||||
/// of MlirVlaues. Returns an empty SmallVector if the list is empty.
|
||||
llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
|
||||
llvm::SmallVector<MlirValue> wrapOperands(std::optional<nb::list> operandList) {
|
||||
llvm::SmallVector<MlirValue> mlirOperands;
|
||||
|
||||
if (!operandList || operandList->empty()) {
|
||||
if (!operandList || operandList->size() == 0) {
|
||||
return mlirOperands;
|
||||
}
|
||||
|
||||
@ -68,40 +68,42 @@ llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
|
||||
|
||||
PyValue *val;
|
||||
try {
|
||||
val = py::cast<PyValue *>(it.value());
|
||||
val = nb::cast<PyValue *>(it.value());
|
||||
if (!val)
|
||||
throw py::cast_error();
|
||||
throw nb::cast_error();
|
||||
mlirOperands.push_back(val->get());
|
||||
continue;
|
||||
} catch (py::cast_error &err) {
|
||||
} catch (nb::cast_error &err) {
|
||||
// Intentionally unhandled to try sequence below first.
|
||||
(void)err;
|
||||
}
|
||||
|
||||
try {
|
||||
auto vals = py::cast<py::sequence>(it.value());
|
||||
for (py::object v : vals) {
|
||||
auto vals = nb::cast<nb::sequence>(it.value());
|
||||
for (nb::handle v : vals) {
|
||||
try {
|
||||
val = py::cast<PyValue *>(v);
|
||||
val = nb::cast<PyValue *>(v);
|
||||
if (!val)
|
||||
throw py::cast_error();
|
||||
throw nb::cast_error();
|
||||
mlirOperands.push_back(val->get());
|
||||
} catch (py::cast_error &err) {
|
||||
throw py::value_error(
|
||||
} catch (nb::cast_error &err) {
|
||||
throw nb::value_error(
|
||||
(llvm::Twine("Operand ") + llvm::Twine(it.index()) +
|
||||
" must be a Value or Sequence of Values (" + err.what() + ")")
|
||||
.str());
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
}
|
||||
continue;
|
||||
} catch (py::cast_error &err) {
|
||||
throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
|
||||
} catch (nb::cast_error &err) {
|
||||
throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
|
||||
" must be a Value or Sequence of Values (" +
|
||||
err.what() + ")")
|
||||
.str());
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
|
||||
throw py::cast_error();
|
||||
throw nb::cast_error();
|
||||
}
|
||||
|
||||
return mlirOperands;
|
||||
@ -144,24 +146,24 @@ wrapRegions(std::optional<std::vector<PyRegion>> regions) {
|
||||
template <typename ConcreteIface>
|
||||
class PyConcreteOpInterface {
|
||||
protected:
|
||||
using ClassTy = py::class_<ConcreteIface>;
|
||||
using ClassTy = nb::class_<ConcreteIface>;
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
|
||||
public:
|
||||
/// Constructs an interface instance from an object that is either an
|
||||
/// operation or a subclass of OpView. In the latter case, only the static
|
||||
/// methods of the interface are accessible to the caller.
|
||||
PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
|
||||
PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
|
||||
: obj(std::move(object)) {
|
||||
try {
|
||||
operation = &py::cast<PyOperation &>(obj);
|
||||
} catch (py::cast_error &) {
|
||||
operation = &nb::cast<PyOperation &>(obj);
|
||||
} catch (nb::cast_error &) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
try {
|
||||
operation = &py::cast<PyOpView &>(obj).getOperation();
|
||||
} catch (py::cast_error &) {
|
||||
operation = &nb::cast<PyOpView &>(obj).getOperation();
|
||||
} catch (nb::cast_error &) {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
@ -169,7 +171,7 @@ public:
|
||||
if (!mlirOperationImplementsInterface(*operation,
|
||||
ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw py::value_error(msg + ConcreteIface::pyClassName);
|
||||
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
|
||||
}
|
||||
|
||||
MlirIdentifier identifier = mlirOperationGetName(*operation);
|
||||
@ -177,9 +179,9 @@ public:
|
||||
opName = std::string(stringRef.data, stringRef.length);
|
||||
} else {
|
||||
try {
|
||||
opName = obj.attr("OPERATION_NAME").template cast<std::string>();
|
||||
} catch (py::cast_error &) {
|
||||
throw py::type_error(
|
||||
opName = nb::cast<std::string>(obj.attr("OPERATION_NAME"));
|
||||
} catch (nb::cast_error &) {
|
||||
throw nb::type_error(
|
||||
"Op interface does not refer to an operation or OpView class");
|
||||
}
|
||||
|
||||
@ -187,22 +189,19 @@ public:
|
||||
mlirStringRefCreate(opName.data(), opName.length()),
|
||||
context.resolve().get(), ConcreteIface::getInterfaceID())) {
|
||||
std::string msg = "the operation does not implement ";
|
||||
throw py::value_error(msg + ConcreteIface::pyClassName);
|
||||
throw nb::value_error((msg + ConcreteIface::pyClassName).c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates the Python bindings for this class in the given module.
|
||||
static void bind(py::module &m) {
|
||||
py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
|
||||
py::module_local());
|
||||
cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
|
||||
py::arg("context") = py::none(), constructorDoc)
|
||||
.def_property_readonly("operation",
|
||||
&PyConcreteOpInterface::getOperationObject,
|
||||
operationDoc)
|
||||
.def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
|
||||
opviewDoc);
|
||||
static void bind(nb::module_ &m) {
|
||||
nb::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName);
|
||||
cls.def(nb::init<nb::object, DefaultingPyMlirContext>(), nb::arg("object"),
|
||||
nb::arg("context").none() = nb::none(), constructorDoc)
|
||||
.def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject,
|
||||
operationDoc)
|
||||
.def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc);
|
||||
ConcreteIface::bindDerived(cls);
|
||||
}
|
||||
|
||||
@ -216,9 +215,9 @@ public:
|
||||
/// Returns the operation instance from which this object was constructed.
|
||||
/// Throws a type error if this object was constructed from a subclass of
|
||||
/// OpView.
|
||||
py::object getOperationObject() {
|
||||
nb::object getOperationObject() {
|
||||
if (operation == nullptr) {
|
||||
throw py::type_error("Cannot get an operation from a static interface");
|
||||
throw nb::type_error("Cannot get an operation from a static interface");
|
||||
}
|
||||
|
||||
return operation->getRef().releaseObject();
|
||||
@ -227,9 +226,9 @@ public:
|
||||
/// Returns the opview of the operation instance from which this object was
|
||||
/// constructed. Throws a type error if this object was constructed form a
|
||||
/// subclass of OpView.
|
||||
py::object getOpView() {
|
||||
nb::object getOpView() {
|
||||
if (operation == nullptr) {
|
||||
throw py::type_error("Cannot get an opview from a static interface");
|
||||
throw nb::type_error("Cannot get an opview from a static interface");
|
||||
}
|
||||
|
||||
return operation->createOpView();
|
||||
@ -242,7 +241,7 @@ public:
|
||||
private:
|
||||
PyOperation *operation = nullptr;
|
||||
std::string opName;
|
||||
py::object obj;
|
||||
nb::object obj;
|
||||
};
|
||||
|
||||
/// Python wrapper for InferTypeOpInterface. This interface has only static
|
||||
@ -276,7 +275,7 @@ public:
|
||||
/// Given the arguments required to build an operation, attempts to infer its
|
||||
/// return types. Throws value_error on failure.
|
||||
std::vector<PyType>
|
||||
inferReturnTypes(std::optional<py::list> operandList,
|
||||
inferReturnTypes(std::optional<nb::list> operandList,
|
||||
std::optional<PyAttribute> attributes, void *properties,
|
||||
std::optional<std::vector<PyRegion>> regions,
|
||||
DefaultingPyMlirContext context,
|
||||
@ -299,7 +298,7 @@ public:
|
||||
mlirRegions.data(), &appendResultsCallback, &data);
|
||||
|
||||
if (mlirLogicalResultIsFailure(result)) {
|
||||
throw py::value_error("Failed to infer result types");
|
||||
throw nb::value_error("Failed to infer result types");
|
||||
}
|
||||
|
||||
return inferredTypes;
|
||||
@ -307,11 +306,12 @@ public:
|
||||
|
||||
static void bindDerived(ClassTy &cls) {
|
||||
cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
|
||||
py::arg("operands") = py::none(),
|
||||
py::arg("attributes") = py::none(),
|
||||
py::arg("properties") = py::none(), py::arg("regions") = py::none(),
|
||||
py::arg("context") = py::none(), py::arg("loc") = py::none(),
|
||||
inferReturnTypesDoc);
|
||||
nb::arg("operands").none() = nb::none(),
|
||||
nb::arg("attributes").none() = nb::none(),
|
||||
nb::arg("properties").none() = nb::none(),
|
||||
nb::arg("regions").none() = nb::none(),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
nb::arg("loc").none() = nb::none(), inferReturnTypesDoc);
|
||||
}
|
||||
};
|
||||
|
||||
@ -319,9 +319,9 @@ public:
|
||||
class PyShapedTypeComponents {
|
||||
public:
|
||||
PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
|
||||
PyShapedTypeComponents(py::list shape, MlirType elementType)
|
||||
PyShapedTypeComponents(nb::list shape, MlirType elementType)
|
||||
: shape(std::move(shape)), elementType(elementType), ranked(true) {}
|
||||
PyShapedTypeComponents(py::list shape, MlirType elementType,
|
||||
PyShapedTypeComponents(nb::list shape, MlirType elementType,
|
||||
MlirAttribute attribute)
|
||||
: shape(std::move(shape)), elementType(elementType), attribute(attribute),
|
||||
ranked(true) {}
|
||||
@ -330,10 +330,9 @@ public:
|
||||
: shape(other.shape), elementType(other.elementType),
|
||||
attribute(other.attribute), ranked(other.ranked) {}
|
||||
|
||||
static void bind(py::module &m) {
|
||||
py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
|
||||
py::module_local())
|
||||
.def_property_readonly(
|
||||
static void bind(nb::module_ &m) {
|
||||
nb::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents")
|
||||
.def_prop_ro(
|
||||
"element_type",
|
||||
[](PyShapedTypeComponents &self) { return self.elementType; },
|
||||
"Returns the element type of the shaped type components.")
|
||||
@ -342,57 +341,57 @@ public:
|
||||
[](PyType &elementType) {
|
||||
return PyShapedTypeComponents(elementType);
|
||||
},
|
||||
py::arg("element_type"),
|
||||
nb::arg("element_type"),
|
||||
"Create an shaped type components object with only the element "
|
||||
"type.")
|
||||
.def_static(
|
||||
"get",
|
||||
[](py::list shape, PyType &elementType) {
|
||||
[](nb::list shape, PyType &elementType) {
|
||||
return PyShapedTypeComponents(std::move(shape), elementType);
|
||||
},
|
||||
py::arg("shape"), py::arg("element_type"),
|
||||
nb::arg("shape"), nb::arg("element_type"),
|
||||
"Create a ranked shaped type components object.")
|
||||
.def_static(
|
||||
"get",
|
||||
[](py::list shape, PyType &elementType, PyAttribute &attribute) {
|
||||
[](nb::list shape, PyType &elementType, PyAttribute &attribute) {
|
||||
return PyShapedTypeComponents(std::move(shape), elementType,
|
||||
attribute);
|
||||
},
|
||||
py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
|
||||
nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"),
|
||||
"Create a ranked shaped type components object with attribute.")
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"has_rank",
|
||||
[](PyShapedTypeComponents &self) -> bool { return self.ranked; },
|
||||
"Returns whether the given shaped type component is ranked.")
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"rank",
|
||||
[](PyShapedTypeComponents &self) -> py::object {
|
||||
[](PyShapedTypeComponents &self) -> nb::object {
|
||||
if (!self.ranked) {
|
||||
return py::none();
|
||||
return nb::none();
|
||||
}
|
||||
return py::int_(self.shape.size());
|
||||
return nb::int_(self.shape.size());
|
||||
},
|
||||
"Returns the rank of the given ranked shaped type components. If "
|
||||
"the shaped type components does not have a rank, None is "
|
||||
"returned.")
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"shape",
|
||||
[](PyShapedTypeComponents &self) -> py::object {
|
||||
[](PyShapedTypeComponents &self) -> nb::object {
|
||||
if (!self.ranked) {
|
||||
return py::none();
|
||||
return nb::none();
|
||||
}
|
||||
return py::list(self.shape);
|
||||
return nb::list(self.shape);
|
||||
},
|
||||
"Returns the shape of the ranked shaped type components as a list "
|
||||
"of integers. Returns none if the shaped type component does not "
|
||||
"have a rank.");
|
||||
}
|
||||
|
||||
pybind11::object getCapsule();
|
||||
static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
|
||||
nb::object getCapsule();
|
||||
static PyShapedTypeComponents createFromCapsule(nb::object capsule);
|
||||
|
||||
private:
|
||||
py::list shape;
|
||||
nb::list shape;
|
||||
MlirType elementType;
|
||||
MlirAttribute attribute;
|
||||
bool ranked{false};
|
||||
@ -424,7 +423,7 @@ public:
|
||||
if (!hasRank) {
|
||||
data->inferredShapedTypeComponents.emplace_back(elementType);
|
||||
} else {
|
||||
py::list shapeList;
|
||||
nb::list shapeList;
|
||||
for (intptr_t i = 0; i < rank; ++i) {
|
||||
shapeList.append(shape[i]);
|
||||
}
|
||||
@ -436,7 +435,7 @@ public:
|
||||
/// Given the arguments required to build an operation, attempts to infer the
|
||||
/// shaped type components. Throws value_error on failure.
|
||||
std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
|
||||
std::optional<py::list> operandList,
|
||||
std::optional<nb::list> operandList,
|
||||
std::optional<PyAttribute> attributes, void *properties,
|
||||
std::optional<std::vector<PyRegion>> regions,
|
||||
DefaultingPyMlirContext context, DefaultingPyLocation location) {
|
||||
@ -458,7 +457,7 @@ public:
|
||||
mlirRegions.data(), &appendResultsCallback, &data);
|
||||
|
||||
if (mlirLogicalResultIsFailure(result)) {
|
||||
throw py::value_error("Failed to infer result shape type components");
|
||||
throw nb::value_error("Failed to infer result shape type components");
|
||||
}
|
||||
|
||||
return inferredShapedTypeComponents;
|
||||
@ -467,14 +466,16 @@ public:
|
||||
static void bindDerived(ClassTy &cls) {
|
||||
cls.def("inferReturnTypeComponents",
|
||||
&PyInferShapedTypeOpInterface::inferReturnTypeComponents,
|
||||
py::arg("operands") = py::none(),
|
||||
py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
|
||||
py::arg("properties") = py::none(), py::arg("context") = py::none(),
|
||||
py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
|
||||
nb::arg("operands").none() = nb::none(),
|
||||
nb::arg("attributes").none() = nb::none(),
|
||||
nb::arg("regions").none() = nb::none(),
|
||||
nb::arg("properties").none() = nb::none(),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc);
|
||||
}
|
||||
};
|
||||
|
||||
void populateIRInterfaces(py::module &m) {
|
||||
void populateIRInterfaces(nb::module_ &m) {
|
||||
PyInferTypeOpInterface::bind(m);
|
||||
PyShapedTypeComponents::bind(m);
|
||||
PyInferShapedTypeOpInterface::bind(m);
|
||||
|
@ -7,16 +7,19 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "IRModule.h"
|
||||
#include "Globals.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Support.h"
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
namespace py = pybind11;
|
||||
#include "Globals.h"
|
||||
#include "NanobindUtils.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
@ -41,14 +44,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
|
||||
return true;
|
||||
// Since re-entrancy is possible, make a copy of the search prefixes.
|
||||
std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
|
||||
py::object loaded = py::none();
|
||||
nb::object loaded = nb::none();
|
||||
for (std::string moduleName : localSearchPrefixes) {
|
||||
moduleName.push_back('.');
|
||||
moduleName.append(dialectNamespace.data(), dialectNamespace.size());
|
||||
|
||||
try {
|
||||
loaded = py::module::import(moduleName.c_str());
|
||||
} catch (py::error_already_set &e) {
|
||||
loaded = nb::module_::import_(moduleName.c_str());
|
||||
} catch (nb::python_error &e) {
|
||||
if (e.matches(PyExc_ModuleNotFoundError)) {
|
||||
continue;
|
||||
}
|
||||
@ -66,41 +69,39 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
|
||||
}
|
||||
|
||||
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
||||
py::function pyFunc, bool replace) {
|
||||
py::object &found = attributeBuilderMap[attributeKind];
|
||||
nb::callable pyFunc, bool replace) {
|
||||
nb::object &found = attributeBuilderMap[attributeKind];
|
||||
if (found && !replace) {
|
||||
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
|
||||
attributeKind +
|
||||
"' is already registered with func: " +
|
||||
py::str(found).operator std::string())
|
||||
nb::cast<std::string>(nb::str(found)))
|
||||
.str());
|
||||
}
|
||||
found = std::move(pyFunc);
|
||||
}
|
||||
|
||||
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
|
||||
pybind11::function typeCaster,
|
||||
bool replace) {
|
||||
pybind11::object &found = typeCasterMap[mlirTypeID];
|
||||
nb::callable typeCaster, bool replace) {
|
||||
nb::object &found = typeCasterMap[mlirTypeID];
|
||||
if (found && !replace)
|
||||
throw std::runtime_error("Type caster is already registered with caster: " +
|
||||
py::str(found).operator std::string());
|
||||
nb::cast<std::string>(nb::str(found)));
|
||||
found = std::move(typeCaster);
|
||||
}
|
||||
|
||||
void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
|
||||
pybind11::function valueCaster,
|
||||
bool replace) {
|
||||
pybind11::object &found = valueCasterMap[mlirTypeID];
|
||||
nb::callable valueCaster, bool replace) {
|
||||
nb::object &found = valueCasterMap[mlirTypeID];
|
||||
if (found && !replace)
|
||||
throw std::runtime_error("Value caster is already registered: " +
|
||||
py::repr(found).cast<std::string>());
|
||||
nb::cast<std::string>(nb::repr(found)));
|
||||
found = std::move(valueCaster);
|
||||
}
|
||||
|
||||
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
py::object pyClass) {
|
||||
py::object &found = dialectClassMap[dialectNamespace];
|
||||
nb::object pyClass) {
|
||||
nb::object &found = dialectClassMap[dialectNamespace];
|
||||
if (found) {
|
||||
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
|
||||
dialectNamespace + "' is already registered.")
|
||||
@ -110,8 +111,8 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
||||
}
|
||||
|
||||
void PyGlobals::registerOperationImpl(const std::string &operationName,
|
||||
py::object pyClass, bool replace) {
|
||||
py::object &found = operationClassMap[operationName];
|
||||
nb::object pyClass, bool replace) {
|
||||
nb::object &found = operationClassMap[operationName];
|
||||
if (found && !replace) {
|
||||
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
|
||||
"' is already registered.")
|
||||
@ -120,7 +121,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName,
|
||||
found = std::move(pyClass);
|
||||
}
|
||||
|
||||
std::optional<py::function>
|
||||
std::optional<nb::callable>
|
||||
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
|
||||
const auto foundIt = attributeBuilderMap.find(attributeKind);
|
||||
if (foundIt != attributeBuilderMap.end()) {
|
||||
@ -130,7 +131,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
std::optional<nb::callable> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect) {
|
||||
// Try to load dialect module.
|
||||
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
|
||||
@ -142,7 +143,7 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
|
||||
std::optional<nb::callable> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
|
||||
MlirDialect dialect) {
|
||||
// Try to load dialect module.
|
||||
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
|
||||
@ -154,7 +155,7 @@ std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<py::object>
|
||||
std::optional<nb::object>
|
||||
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
|
||||
// Make sure dialect module is loaded.
|
||||
if (!loadDialectModule(dialectNamespace))
|
||||
@ -168,7 +169,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<pybind11::object>
|
||||
std::optional<nb::object>
|
||||
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
|
||||
// Make sure dialect module is loaded.
|
||||
auto split = operationName.split('.');
|
||||
|
@ -10,20 +10,22 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
|
||||
#define MLIR_BINDINGS_PYTHON_IRMODULES_H
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Globals.h"
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "NanobindUtils.h"
|
||||
#include "mlir-c/AffineExpr.h"
|
||||
#include "mlir-c/AffineMap.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/IntegerSet.h"
|
||||
#include "mlir-c/Transforms.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
namespace mlir {
|
||||
@ -49,7 +51,7 @@ class PyValue;
|
||||
template <typename T>
|
||||
class PyObjectRef {
|
||||
public:
|
||||
PyObjectRef(T *referrent, pybind11::object object)
|
||||
PyObjectRef(T *referrent, nanobind::object object)
|
||||
: referrent(referrent), object(std::move(object)) {
|
||||
assert(this->referrent &&
|
||||
"cannot construct PyObjectRef with null referrent");
|
||||
@ -67,13 +69,13 @@ public:
|
||||
int getRefCount() {
|
||||
if (!object)
|
||||
return 0;
|
||||
return object.ref_count();
|
||||
return Py_REFCNT(object.ptr());
|
||||
}
|
||||
|
||||
/// Releases the object held by this instance, returning it.
|
||||
/// This is the proper thing to return from a function that wants to return
|
||||
/// the reference. Note that this does not work from initializers.
|
||||
pybind11::object releaseObject() {
|
||||
nanobind::object releaseObject() {
|
||||
assert(referrent && object);
|
||||
referrent = nullptr;
|
||||
auto stolen = std::move(object);
|
||||
@ -85,7 +87,7 @@ public:
|
||||
assert(referrent && object);
|
||||
return referrent;
|
||||
}
|
||||
pybind11::object getObject() {
|
||||
nanobind::object getObject() {
|
||||
assert(referrent && object);
|
||||
return object;
|
||||
}
|
||||
@ -93,7 +95,7 @@ public:
|
||||
|
||||
private:
|
||||
T *referrent;
|
||||
pybind11::object object;
|
||||
nanobind::object object;
|
||||
};
|
||||
|
||||
/// Tracks an entry in the thread context stack. New entries are pushed onto
|
||||
@ -112,9 +114,9 @@ public:
|
||||
Location,
|
||||
};
|
||||
|
||||
PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
|
||||
pybind11::object insertionPoint,
|
||||
pybind11::object location)
|
||||
PyThreadContextEntry(FrameKind frameKind, nanobind::object context,
|
||||
nanobind::object insertionPoint,
|
||||
nanobind::object location)
|
||||
: context(std::move(context)), insertionPoint(std::move(insertionPoint)),
|
||||
location(std::move(location)), frameKind(frameKind) {}
|
||||
|
||||
@ -134,26 +136,26 @@ public:
|
||||
|
||||
/// Stack management.
|
||||
static PyThreadContextEntry *getTopOfStack();
|
||||
static pybind11::object pushContext(PyMlirContext &context);
|
||||
static nanobind::object pushContext(nanobind::object context);
|
||||
static void popContext(PyMlirContext &context);
|
||||
static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
|
||||
static nanobind::object pushInsertionPoint(nanobind::object insertionPoint);
|
||||
static void popInsertionPoint(PyInsertionPoint &insertionPoint);
|
||||
static pybind11::object pushLocation(PyLocation &location);
|
||||
static nanobind::object pushLocation(nanobind::object location);
|
||||
static void popLocation(PyLocation &location);
|
||||
|
||||
/// Gets the thread local stack.
|
||||
static std::vector<PyThreadContextEntry> &getStack();
|
||||
|
||||
private:
|
||||
static void push(FrameKind frameKind, pybind11::object context,
|
||||
pybind11::object insertionPoint, pybind11::object location);
|
||||
static void push(FrameKind frameKind, nanobind::object context,
|
||||
nanobind::object insertionPoint, nanobind::object location);
|
||||
|
||||
/// An object reference to the PyContext.
|
||||
pybind11::object context;
|
||||
nanobind::object context;
|
||||
/// An object reference to the current insertion point.
|
||||
pybind11::object insertionPoint;
|
||||
nanobind::object insertionPoint;
|
||||
/// An object reference to the current location.
|
||||
pybind11::object location;
|
||||
nanobind::object location;
|
||||
// The kind of push that was performed.
|
||||
FrameKind frameKind;
|
||||
};
|
||||
@ -163,14 +165,15 @@ using PyMlirContextRef = PyObjectRef<PyMlirContext>;
|
||||
class PyMlirContext {
|
||||
public:
|
||||
PyMlirContext() = delete;
|
||||
PyMlirContext(MlirContext context);
|
||||
PyMlirContext(const PyMlirContext &) = delete;
|
||||
PyMlirContext(PyMlirContext &&) = delete;
|
||||
|
||||
/// For the case of a python __init__ (py::init) method, pybind11 is quite
|
||||
/// strict about needing to return a pointer that is not yet associated to
|
||||
/// an py::object. Since the forContext() method acts like a pool, possibly
|
||||
/// returning a recycled context, it does not satisfy this need. The usual
|
||||
/// way in python to accomplish such a thing is to override __new__, but
|
||||
/// For the case of a python __init__ (nanobind::init) method, pybind11 is
|
||||
/// quite strict about needing to return a pointer that is not yet associated
|
||||
/// to an nanobind::object. Since the forContext() method acts like a pool,
|
||||
/// possibly returning a recycled context, it does not satisfy this need. The
|
||||
/// usual way in python to accomplish such a thing is to override __new__, but
|
||||
/// that is also not supported by pybind11. Instead, we use this entry
|
||||
/// point which always constructs a fresh context (which cannot alias an
|
||||
/// existing one because it is fresh).
|
||||
@ -187,17 +190,17 @@ public:
|
||||
/// Gets a strong reference to this context, which will ensure it is kept
|
||||
/// alive for the life of the reference.
|
||||
PyMlirContextRef getRef() {
|
||||
return PyMlirContextRef(this, pybind11::cast(this));
|
||||
return PyMlirContextRef(this, nanobind::cast(this));
|
||||
}
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirContext.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
|
||||
/// Note that PyMlirContext instances are uniqued, so the returned object
|
||||
/// may be a pre-existing object. Ownership of the underlying MlirContext
|
||||
/// is taken by calling this function.
|
||||
static pybind11::object createFromCapsule(pybind11::object capsule);
|
||||
static nanobind::object createFromCapsule(nanobind::object capsule);
|
||||
|
||||
/// Gets the count of live context objects. Used for testing.
|
||||
static size_t getLiveCount();
|
||||
@ -237,14 +240,14 @@ public:
|
||||
size_t getLiveModuleCount();
|
||||
|
||||
/// Enter and exit the context manager.
|
||||
pybind11::object contextEnter();
|
||||
void contextExit(const pybind11::object &excType,
|
||||
const pybind11::object &excVal,
|
||||
const pybind11::object &excTb);
|
||||
static nanobind::object contextEnter(nanobind::object context);
|
||||
void contextExit(const nanobind::object &excType,
|
||||
const nanobind::object &excVal,
|
||||
const nanobind::object &excTb);
|
||||
|
||||
/// Attaches a Python callback as a diagnostic handler, returning a
|
||||
/// registration object (internally a PyDiagnosticHandler).
|
||||
pybind11::object attachDiagnosticHandler(pybind11::object callback);
|
||||
nanobind::object attachDiagnosticHandler(nanobind::object callback);
|
||||
|
||||
/// Controls whether error diagnostics should be propagated to diagnostic
|
||||
/// handlers, instead of being captured by `ErrorCapture`.
|
||||
@ -252,8 +255,6 @@ public:
|
||||
struct ErrorCapture;
|
||||
|
||||
private:
|
||||
PyMlirContext(MlirContext context);
|
||||
|
||||
// Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
|
||||
// preserving the relationship that an MlirContext maps to a single
|
||||
// PyMlirContext wrapper. This could be replaced in the future with an
|
||||
@ -268,7 +269,7 @@ private:
|
||||
// from this map, and while it still exists as an instance, any
|
||||
// attempt to access it will raise an error.
|
||||
using LiveModuleMap =
|
||||
llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
|
||||
llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>;
|
||||
LiveModuleMap liveModules;
|
||||
|
||||
// Interns all live operations associated with this context. Operations
|
||||
@ -276,7 +277,7 @@ private:
|
||||
// removed from this map, and while it still exists as an instance, any
|
||||
// attempt to access it will raise an error.
|
||||
using LiveOperationMap =
|
||||
llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
|
||||
llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>;
|
||||
LiveOperationMap liveOperations;
|
||||
|
||||
bool emitErrorDiagnostics = false;
|
||||
@ -324,19 +325,19 @@ public:
|
||||
MlirLocation get() const { return loc; }
|
||||
|
||||
/// Enter and exit the context manager.
|
||||
pybind11::object contextEnter();
|
||||
void contextExit(const pybind11::object &excType,
|
||||
const pybind11::object &excVal,
|
||||
const pybind11::object &excTb);
|
||||
static nanobind::object contextEnter(nanobind::object location);
|
||||
void contextExit(const nanobind::object &excType,
|
||||
const nanobind::object &excVal,
|
||||
const nanobind::object &excTb);
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirLocation.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyLocation from the MlirLocation wrapped by a capsule.
|
||||
/// Note that PyLocation instances are uniqued, so the returned object
|
||||
/// may be a pre-existing object. Ownership of the underlying MlirLocation
|
||||
/// is taken by calling this function.
|
||||
static PyLocation createFromCapsule(pybind11::object capsule);
|
||||
static PyLocation createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
MlirLocation loc;
|
||||
@ -353,8 +354,8 @@ public:
|
||||
bool isValid() { return valid; }
|
||||
MlirDiagnosticSeverity getSeverity();
|
||||
PyLocation getLocation();
|
||||
pybind11::str getMessage();
|
||||
pybind11::tuple getNotes();
|
||||
nanobind::str getMessage();
|
||||
nanobind::tuple getNotes();
|
||||
|
||||
/// Materialized diagnostic information. This is safe to access outside the
|
||||
/// diagnostic callback.
|
||||
@ -373,7 +374,7 @@ private:
|
||||
/// If notes have been materialized from the diagnostic, then this will
|
||||
/// be populated with the corresponding objects (all castable to
|
||||
/// PyDiagnostic).
|
||||
std::optional<pybind11::tuple> materializedNotes;
|
||||
std::optional<nanobind::tuple> materializedNotes;
|
||||
bool valid = true;
|
||||
};
|
||||
|
||||
@ -398,7 +399,7 @@ private:
|
||||
/// is no way to attach an existing handler object).
|
||||
class PyDiagnosticHandler {
|
||||
public:
|
||||
PyDiagnosticHandler(MlirContext context, pybind11::object callback);
|
||||
PyDiagnosticHandler(MlirContext context, nanobind::object callback);
|
||||
~PyDiagnosticHandler();
|
||||
|
||||
bool isAttached() { return registeredID.has_value(); }
|
||||
@ -407,16 +408,16 @@ public:
|
||||
/// Detaches the handler. Does nothing if not attached.
|
||||
void detach();
|
||||
|
||||
pybind11::object contextEnter() { return pybind11::cast(this); }
|
||||
void contextExit(const pybind11::object &excType,
|
||||
const pybind11::object &excVal,
|
||||
const pybind11::object &excTb) {
|
||||
nanobind::object contextEnter() { return nanobind::cast(this); }
|
||||
void contextExit(const nanobind::object &excType,
|
||||
const nanobind::object &excVal,
|
||||
const nanobind::object &excTb) {
|
||||
detach();
|
||||
}
|
||||
|
||||
private:
|
||||
MlirContext context;
|
||||
pybind11::object callback;
|
||||
nanobind::object callback;
|
||||
std::optional<MlirDiagnosticHandlerID> registeredID;
|
||||
bool hadError = false;
|
||||
friend class PyMlirContext;
|
||||
@ -477,12 +478,12 @@ public:
|
||||
/// objects of this type will be returned directly.
|
||||
class PyDialect {
|
||||
public:
|
||||
PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
|
||||
PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {}
|
||||
|
||||
pybind11::object getDescriptor() { return descriptor; }
|
||||
nanobind::object getDescriptor() { return descriptor; }
|
||||
|
||||
private:
|
||||
pybind11::object descriptor;
|
||||
nanobind::object descriptor;
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirDialectRegistry.
|
||||
@ -505,8 +506,8 @@ public:
|
||||
operator MlirDialectRegistry() const { return registry; }
|
||||
MlirDialectRegistry get() const { return registry; }
|
||||
|
||||
pybind11::object getCapsule();
|
||||
static PyDialectRegistry createFromCapsule(pybind11::object capsule);
|
||||
nanobind::object getCapsule();
|
||||
static PyDialectRegistry createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
MlirDialectRegistry registry;
|
||||
@ -542,26 +543,25 @@ public:
|
||||
|
||||
/// Gets a strong reference to this module.
|
||||
PyModuleRef getRef() {
|
||||
return PyModuleRef(this,
|
||||
pybind11::reinterpret_borrow<pybind11::object>(handle));
|
||||
return PyModuleRef(this, nanobind::borrow<nanobind::object>(handle));
|
||||
}
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirModule.
|
||||
/// Note that the module does not (yet) provide a corresponding factory for
|
||||
/// constructing from a capsule as that would require uniquing PyModule
|
||||
/// instances, which is not currently done.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyModule from the MlirModule wrapped by a capsule.
|
||||
/// Note that PyModule instances are uniqued, so the returned object
|
||||
/// may be a pre-existing object. Ownership of the underlying MlirModule
|
||||
/// is taken by calling this function.
|
||||
static pybind11::object createFromCapsule(pybind11::object capsule);
|
||||
static nanobind::object createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
PyModule(PyMlirContextRef contextRef, MlirModule module);
|
||||
MlirModule module;
|
||||
pybind11::handle handle;
|
||||
nanobind::handle handle;
|
||||
};
|
||||
|
||||
class PyAsmState;
|
||||
@ -574,18 +574,18 @@ public:
|
||||
/// Implements the bound 'print' method and helps with others.
|
||||
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
|
||||
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
|
||||
bool assumeVerified, py::object fileObject, bool binary,
|
||||
bool assumeVerified, nanobind::object fileObject, bool binary,
|
||||
bool skipRegions);
|
||||
void print(PyAsmState &state, py::object fileObject, bool binary);
|
||||
void print(PyAsmState &state, nanobind::object fileObject, bool binary);
|
||||
|
||||
pybind11::object getAsm(bool binary,
|
||||
nanobind::object getAsm(bool binary,
|
||||
std::optional<int64_t> largeElementsLimit,
|
||||
bool enableDebugInfo, bool prettyDebugInfo,
|
||||
bool printGenericOpForm, bool useLocalScope,
|
||||
bool assumeVerified, bool skipRegions);
|
||||
|
||||
// Implement the bound 'writeBytecode' method.
|
||||
void writeBytecode(const pybind11::object &fileObject,
|
||||
void writeBytecode(const nanobind::object &fileObject,
|
||||
std::optional<int64_t> bytecodeVersion);
|
||||
|
||||
// Implement the walk method.
|
||||
@ -621,13 +621,13 @@ public:
|
||||
/// it with a parentKeepAlive.
|
||||
static PyOperationRef
|
||||
forOperation(PyMlirContextRef contextRef, MlirOperation operation,
|
||||
pybind11::object parentKeepAlive = pybind11::object());
|
||||
nanobind::object parentKeepAlive = nanobind::object());
|
||||
|
||||
/// Creates a detached operation. The operation must not be associated with
|
||||
/// any existing live operation.
|
||||
static PyOperationRef
|
||||
createDetached(PyMlirContextRef contextRef, MlirOperation operation,
|
||||
pybind11::object parentKeepAlive = pybind11::object());
|
||||
nanobind::object parentKeepAlive = nanobind::object());
|
||||
|
||||
/// Parses a source string (either text assembly or bytecode), creating a
|
||||
/// detached operation.
|
||||
@ -640,7 +640,7 @@ public:
|
||||
void detachFromParent() {
|
||||
mlirOperationRemoveFromParent(getOperation());
|
||||
setDetached();
|
||||
parentKeepAlive = pybind11::object();
|
||||
parentKeepAlive = nanobind::object();
|
||||
}
|
||||
|
||||
/// Gets the backing operation.
|
||||
@ -651,12 +651,11 @@ public:
|
||||
}
|
||||
|
||||
PyOperationRef getRef() {
|
||||
return PyOperationRef(
|
||||
this, pybind11::reinterpret_borrow<pybind11::object>(handle));
|
||||
return PyOperationRef(this, nanobind::borrow<nanobind::object>(handle));
|
||||
}
|
||||
|
||||
bool isAttached() { return attached; }
|
||||
void setAttached(const pybind11::object &parent = pybind11::object()) {
|
||||
void setAttached(const nanobind::object &parent = nanobind::object()) {
|
||||
assert(!attached && "operation already attached");
|
||||
attached = true;
|
||||
}
|
||||
@ -675,24 +674,24 @@ public:
|
||||
std::optional<PyOperationRef> getParentOperation();
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirOperation.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyOperation from the MlirOperation wrapped by a capsule.
|
||||
/// Ownership of the underlying MlirOperation is taken by calling this
|
||||
/// function.
|
||||
static pybind11::object createFromCapsule(pybind11::object capsule);
|
||||
static nanobind::object createFromCapsule(nanobind::object capsule);
|
||||
|
||||
/// Creates an operation. See corresponding python docstring.
|
||||
static pybind11::object
|
||||
static nanobind::object
|
||||
create(const std::string &name, std::optional<std::vector<PyType *>> results,
|
||||
std::optional<std::vector<PyValue *>> operands,
|
||||
std::optional<pybind11::dict> attributes,
|
||||
std::optional<nanobind::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors, int regions,
|
||||
DefaultingPyLocation location, const pybind11::object &ip,
|
||||
DefaultingPyLocation location, const nanobind::object &ip,
|
||||
bool inferType);
|
||||
|
||||
/// Creates an OpView suitable for this operation.
|
||||
pybind11::object createOpView();
|
||||
nanobind::object createOpView();
|
||||
|
||||
/// Erases the underlying MlirOperation, removes its pointer from the
|
||||
/// parent context's live operations map, and sets the valid bit false.
|
||||
@ -702,23 +701,23 @@ public:
|
||||
void setInvalid() { valid = false; }
|
||||
|
||||
/// Clones this operation.
|
||||
pybind11::object clone(const pybind11::object &ip);
|
||||
nanobind::object clone(const nanobind::object &ip);
|
||||
|
||||
private:
|
||||
PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
|
||||
static PyOperationRef createInstance(PyMlirContextRef contextRef,
|
||||
MlirOperation operation,
|
||||
pybind11::object parentKeepAlive);
|
||||
nanobind::object parentKeepAlive);
|
||||
|
||||
MlirOperation operation;
|
||||
pybind11::handle handle;
|
||||
nanobind::handle handle;
|
||||
// Keeps the parent alive, regardless of whether it is an Operation or
|
||||
// Module.
|
||||
// TODO: As implemented, this facility is only sufficient for modeling the
|
||||
// trivial module parent back-reference. Generalize this to also account for
|
||||
// transitions from detached to attached and address TODOs in the
|
||||
// ir_operation.py regarding testing corresponding lifetime guarantees.
|
||||
pybind11::object parentKeepAlive;
|
||||
nanobind::object parentKeepAlive;
|
||||
bool attached = true;
|
||||
bool valid = true;
|
||||
|
||||
@ -733,17 +732,17 @@ private:
|
||||
/// python types.
|
||||
class PyOpView : public PyOperationBase {
|
||||
public:
|
||||
PyOpView(const pybind11::object &operationObject);
|
||||
PyOpView(const nanobind::object &operationObject);
|
||||
PyOperation &getOperation() override { return operation; }
|
||||
|
||||
pybind11::object getOperationObject() { return operationObject; }
|
||||
nanobind::object getOperationObject() { return operationObject; }
|
||||
|
||||
static pybind11::object buildGeneric(
|
||||
const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
|
||||
pybind11::list operandList, std::optional<pybind11::dict> attributes,
|
||||
static nanobind::object buildGeneric(
|
||||
const nanobind::object &cls, std::optional<nanobind::list> resultTypeList,
|
||||
nanobind::list operandList, std::optional<nanobind::dict> attributes,
|
||||
std::optional<std::vector<PyBlock *>> successors,
|
||||
std::optional<int> regions, DefaultingPyLocation location,
|
||||
const pybind11::object &maybeIp);
|
||||
const nanobind::object &maybeIp);
|
||||
|
||||
/// Construct an instance of a class deriving from OpView, bypassing its
|
||||
/// `__init__` method. The derived class will typically define a constructor
|
||||
@ -752,12 +751,12 @@ public:
|
||||
///
|
||||
/// The caller is responsible for verifying that `operation` is a valid
|
||||
/// operation to construct `cls` with.
|
||||
static pybind11::object constructDerived(const pybind11::object &cls,
|
||||
const PyOperation &operation);
|
||||
static nanobind::object constructDerived(const nanobind::object &cls,
|
||||
const nanobind::object &operation);
|
||||
|
||||
private:
|
||||
PyOperation &operation; // For efficient, cast-free access from C++
|
||||
pybind11::object operationObject; // Holds the reference.
|
||||
nanobind::object operationObject; // Holds the reference.
|
||||
};
|
||||
|
||||
/// Wrapper around an MlirRegion.
|
||||
@ -830,7 +829,7 @@ public:
|
||||
void checkValid() { return parentOperation->checkValid(); }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirBlock.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
private:
|
||||
PyOperationRef parentOperation;
|
||||
@ -858,10 +857,10 @@ public:
|
||||
void insert(PyOperationBase &operationBase);
|
||||
|
||||
/// Enter and exit the context manager.
|
||||
pybind11::object contextEnter();
|
||||
void contextExit(const pybind11::object &excType,
|
||||
const pybind11::object &excVal,
|
||||
const pybind11::object &excTb);
|
||||
static nanobind::object contextEnter(nanobind::object insertionPoint);
|
||||
void contextExit(const nanobind::object &excType,
|
||||
const nanobind::object &excVal,
|
||||
const nanobind::object &excTb);
|
||||
|
||||
PyBlock &getBlock() { return block; }
|
||||
std::optional<PyOperationRef> &getRefOperation() { return refOperation; }
|
||||
@ -886,13 +885,13 @@ public:
|
||||
MlirType get() const { return type; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirType.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyType from the MlirType wrapped by a capsule.
|
||||
/// Note that PyType instances are uniqued, so the returned object
|
||||
/// may be a pre-existing object. Ownership of the underlying MlirType
|
||||
/// is taken by calling this function.
|
||||
static PyType createFromCapsule(pybind11::object capsule);
|
||||
static PyType createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
MlirType type;
|
||||
@ -912,10 +911,10 @@ public:
|
||||
MlirTypeID get() { return typeID; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirTypeID.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyTypeID from the MlirTypeID wrapped by a capsule.
|
||||
static PyTypeID createFromCapsule(pybind11::object capsule);
|
||||
static PyTypeID createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
MlirTypeID typeID;
|
||||
@ -932,7 +931,7 @@ public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
|
||||
using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = bool (*)(MlirType);
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
|
||||
@ -945,34 +944,38 @@ public:
|
||||
|
||||
static MlirType castFrom(PyType &orig) {
|
||||
if (!DerivedTy::isaFunction(orig)) {
|
||||
auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
|
||||
throw py::value_error((llvm::Twine("Cannot cast type to ") +
|
||||
DerivedTy::pyClassName + " (from " + origRepr +
|
||||
")")
|
||||
.str());
|
||||
auto origRepr =
|
||||
nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
|
||||
throw nanobind::value_error((llvm::Twine("Cannot cast type to ") +
|
||||
DerivedTy::pyClassName + " (from " +
|
||||
origRepr + ")")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
return orig;
|
||||
}
|
||||
|
||||
static void bind(pybind11::module &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
|
||||
cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
|
||||
pybind11::arg("cast_from_type"));
|
||||
static void bind(nanobind::module_ &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
cls.def(nanobind::init<PyType &>(), nanobind::keep_alive<0, 1>(),
|
||||
nanobind::arg("cast_from_type"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyType &otherType) -> bool {
|
||||
return DerivedTy::isaFunction(otherType);
|
||||
},
|
||||
pybind11::arg("other"));
|
||||
cls.def_property_readonly_static(
|
||||
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
|
||||
nanobind::arg("other"));
|
||||
cls.def_prop_ro_static(
|
||||
"static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
|
||||
if (DerivedTy::getTypeIdFunction)
|
||||
return DerivedTy::getTypeIdFunction();
|
||||
throw py::attribute_error(
|
||||
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
|
||||
throw nanobind::attribute_error(
|
||||
(DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
|
||||
.str()
|
||||
.c_str());
|
||||
});
|
||||
cls.def_property_readonly("typeid", [](PyType &self) {
|
||||
return py::cast(self).attr("typeid").cast<MlirTypeID>();
|
||||
cls.def_prop_ro("typeid", [](PyType &self) {
|
||||
return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
|
||||
});
|
||||
cls.def("__repr__", [](DerivedTy &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
@ -986,8 +989,8 @@ public:
|
||||
if (DerivedTy::getTypeIdFunction) {
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
DerivedTy::getTypeIdFunction(),
|
||||
pybind11::cpp_function(
|
||||
[](PyType pyType) -> DerivedTy { return pyType; }));
|
||||
nanobind::cast<nanobind::callable>(nanobind::cpp_function(
|
||||
[](PyType pyType) -> DerivedTy { return pyType; })));
|
||||
}
|
||||
|
||||
DerivedTy::bindDerived(cls);
|
||||
@ -1008,13 +1011,13 @@ public:
|
||||
MlirAttribute get() const { return attr; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirAttribute.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
|
||||
/// Note that PyAttribute instances are uniqued, so the returned object
|
||||
/// may be a pre-existing object. Ownership of the underlying MlirAttribute
|
||||
/// is taken by calling this function.
|
||||
static PyAttribute createFromCapsule(pybind11::object capsule);
|
||||
static PyAttribute createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
MlirAttribute attr;
|
||||
@ -1054,7 +1057,7 @@ public:
|
||||
// Derived classes must define statics for:
|
||||
// IsAFunctionTy isaFunction
|
||||
// const char *pyClassName
|
||||
using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
|
||||
using ClassTy = nanobind::class_<DerivedTy, BaseTy>;
|
||||
using IsAFunctionTy = bool (*)(MlirAttribute);
|
||||
using GetTypeIDFunctionTy = MlirTypeID (*)();
|
||||
static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
|
||||
@ -1067,37 +1070,45 @@ public:
|
||||
|
||||
static MlirAttribute castFrom(PyAttribute &orig) {
|
||||
if (!DerivedTy::isaFunction(orig)) {
|
||||
auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
|
||||
throw py::value_error((llvm::Twine("Cannot cast attribute to ") +
|
||||
DerivedTy::pyClassName + " (from " + origRepr +
|
||||
")")
|
||||
.str());
|
||||
auto origRepr =
|
||||
nanobind::cast<std::string>(nanobind::repr(nanobind::cast(orig)));
|
||||
throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") +
|
||||
DerivedTy::pyClassName + " (from " +
|
||||
origRepr + ")")
|
||||
.str()
|
||||
.c_str());
|
||||
}
|
||||
return orig;
|
||||
}
|
||||
|
||||
static void bind(pybind11::module &m) {
|
||||
auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
|
||||
pybind11::module_local());
|
||||
cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
|
||||
pybind11::arg("cast_from_attr"));
|
||||
static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) {
|
||||
ClassTy cls;
|
||||
if (slots) {
|
||||
cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots));
|
||||
} else {
|
||||
cls = ClassTy(m, DerivedTy::pyClassName);
|
||||
}
|
||||
cls.def(nanobind::init<PyAttribute &>(), nanobind::keep_alive<0, 1>(),
|
||||
nanobind::arg("cast_from_attr"));
|
||||
cls.def_static(
|
||||
"isinstance",
|
||||
[](PyAttribute &otherAttr) -> bool {
|
||||
return DerivedTy::isaFunction(otherAttr);
|
||||
},
|
||||
pybind11::arg("other"));
|
||||
cls.def_property_readonly(
|
||||
nanobind::arg("other"));
|
||||
cls.def_prop_ro(
|
||||
"type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); });
|
||||
cls.def_property_readonly_static(
|
||||
"static_typeid", [](py::object & /*class*/) -> MlirTypeID {
|
||||
cls.def_prop_ro_static(
|
||||
"static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID {
|
||||
if (DerivedTy::getTypeIdFunction)
|
||||
return DerivedTy::getTypeIdFunction();
|
||||
throw py::attribute_error(
|
||||
(DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str());
|
||||
throw nanobind::attribute_error(
|
||||
(DerivedTy::pyClassName + llvm::Twine(" has no typeid."))
|
||||
.str()
|
||||
.c_str());
|
||||
});
|
||||
cls.def_property_readonly("typeid", [](PyAttribute &self) {
|
||||
return py::cast(self).attr("typeid").cast<MlirTypeID>();
|
||||
cls.def_prop_ro("typeid", [](PyAttribute &self) {
|
||||
return nanobind::cast<MlirTypeID>(nanobind::cast(self).attr("typeid"));
|
||||
});
|
||||
cls.def("__repr__", [](DerivedTy &self) {
|
||||
PyPrintAccumulator printAccum;
|
||||
@ -1112,9 +1123,10 @@ public:
|
||||
if (DerivedTy::getTypeIdFunction) {
|
||||
PyGlobals::get().registerTypeCaster(
|
||||
DerivedTy::getTypeIdFunction(),
|
||||
pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
|
||||
return pyAttribute;
|
||||
}));
|
||||
nanobind::cast<nanobind::callable>(
|
||||
nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
|
||||
return pyAttribute;
|
||||
})));
|
||||
}
|
||||
|
||||
DerivedTy::bindDerived(cls);
|
||||
@ -1146,13 +1158,13 @@ public:
|
||||
void checkValid() { return parentOperation->checkValid(); }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirValue.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
pybind11::object maybeDownCast();
|
||||
nanobind::object maybeDownCast();
|
||||
|
||||
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
|
||||
/// the underlying MlirValue is still tied to the owning operation.
|
||||
static PyValue createFromCapsule(pybind11::object capsule);
|
||||
static PyValue createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
PyOperationRef parentOperation;
|
||||
@ -1169,13 +1181,13 @@ public:
|
||||
MlirAffineExpr get() const { return affineExpr; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirAffineExpr.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
|
||||
/// Note that PyAffineExpr instances are uniqued, so the returned object
|
||||
/// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
|
||||
/// is taken by calling this function.
|
||||
static PyAffineExpr createFromCapsule(pybind11::object capsule);
|
||||
static PyAffineExpr createFromCapsule(nanobind::object capsule);
|
||||
|
||||
PyAffineExpr add(const PyAffineExpr &other) const;
|
||||
PyAffineExpr mul(const PyAffineExpr &other) const;
|
||||
@ -1196,13 +1208,13 @@ public:
|
||||
MlirAffineMap get() const { return affineMap; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirAffineMap.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
|
||||
/// Note that PyAffineMap instances are uniqued, so the returned object
|
||||
/// may be a pre-existing object. Ownership of the underlying MlirAffineMap
|
||||
/// is taken by calling this function.
|
||||
static PyAffineMap createFromCapsule(pybind11::object capsule);
|
||||
static PyAffineMap createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
MlirAffineMap affineMap;
|
||||
@ -1217,12 +1229,12 @@ public:
|
||||
MlirIntegerSet get() const { return integerSet; }
|
||||
|
||||
/// Gets a capsule wrapping the void* within the MlirIntegerSet.
|
||||
pybind11::object getCapsule();
|
||||
nanobind::object getCapsule();
|
||||
|
||||
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
|
||||
/// Note that PyIntegerSet instances may be uniqued, so the returned object
|
||||
/// may be a pre-existing object. Integer sets are owned by the context.
|
||||
static PyIntegerSet createFromCapsule(pybind11::object capsule);
|
||||
static PyIntegerSet createFromCapsule(nanobind::object capsule);
|
||||
|
||||
private:
|
||||
MlirIntegerSet integerSet;
|
||||
@ -1239,7 +1251,7 @@ public:
|
||||
|
||||
/// Returns the symbol (opview) with the given name, throws if there is no
|
||||
/// such symbol in the table.
|
||||
pybind11::object dunderGetItem(const std::string &name);
|
||||
nanobind::object dunderGetItem(const std::string &name);
|
||||
|
||||
/// Removes the given operation from the symbol table and erases it.
|
||||
void erase(PyOperationBase &symbol);
|
||||
@ -1269,7 +1281,7 @@ public:
|
||||
|
||||
/// Walks all symbol tables under and including 'from'.
|
||||
static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
|
||||
pybind11::object callback);
|
||||
nanobind::object callback);
|
||||
|
||||
/// Casts the bindings class into the C API structure.
|
||||
operator MlirSymbolTable() { return symbolTable; }
|
||||
@ -1289,16 +1301,16 @@ struct MLIRError {
|
||||
std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics;
|
||||
};
|
||||
|
||||
void populateIRAffine(pybind11::module &m);
|
||||
void populateIRAttributes(pybind11::module &m);
|
||||
void populateIRCore(pybind11::module &m);
|
||||
void populateIRInterfaces(pybind11::module &m);
|
||||
void populateIRTypes(pybind11::module &m);
|
||||
void populateIRAffine(nanobind::module_ &m);
|
||||
void populateIRAttributes(nanobind::module_ &m);
|
||||
void populateIRCore(nanobind::module_ &m);
|
||||
void populateIRInterfaces(nanobind::module_ &m);
|
||||
void populateIRTypes(nanobind::module_ &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
||||
namespace pybind11 {
|
||||
namespace nanobind {
|
||||
namespace detail {
|
||||
|
||||
template <>
|
||||
@ -1309,6 +1321,6 @@ struct type_caster<mlir::python::DefaultingPyLocation>
|
||||
: MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
} // namespace nanobind
|
||||
|
||||
#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
|
||||
|
@ -6,19 +6,26 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// clang-format off
|
||||
#include "IRModule.h"
|
||||
|
||||
#include "PybindUtils.h"
|
||||
|
||||
#include "mlir/Bindings/Python/IRTypes.h"
|
||||
// clang-format on
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/pair.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
#include <nanobind/stl/vector.h>
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "IRModule.h"
|
||||
#include "NanobindUtils.h"
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
@ -48,7 +55,7 @@ public:
|
||||
MlirType t = mlirIntegerTypeGet(context->get(), width);
|
||||
return PyIntegerType(context->getRef(), t);
|
||||
},
|
||||
py::arg("width"), py::arg("context") = py::none(),
|
||||
nb::arg("width"), nb::arg("context").none() = nb::none(),
|
||||
"Create a signless integer type");
|
||||
c.def_static(
|
||||
"get_signed",
|
||||
@ -56,7 +63,7 @@ public:
|
||||
MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
|
||||
return PyIntegerType(context->getRef(), t);
|
||||
},
|
||||
py::arg("width"), py::arg("context") = py::none(),
|
||||
nb::arg("width"), nb::arg("context").none() = nb::none(),
|
||||
"Create a signed integer type");
|
||||
c.def_static(
|
||||
"get_unsigned",
|
||||
@ -64,25 +71,25 @@ public:
|
||||
MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
|
||||
return PyIntegerType(context->getRef(), t);
|
||||
},
|
||||
py::arg("width"), py::arg("context") = py::none(),
|
||||
nb::arg("width"), nb::arg("context").none() = nb::none(),
|
||||
"Create an unsigned integer type");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"width",
|
||||
[](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
|
||||
"Returns the width of the integer type");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"is_signless",
|
||||
[](PyIntegerType &self) -> bool {
|
||||
return mlirIntegerTypeIsSignless(self);
|
||||
},
|
||||
"Returns whether this is a signless integer");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"is_signed",
|
||||
[](PyIntegerType &self) -> bool {
|
||||
return mlirIntegerTypeIsSigned(self);
|
||||
},
|
||||
"Returns whether this is a signed integer");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"is_unsigned",
|
||||
[](PyIntegerType &self) -> bool {
|
||||
return mlirIntegerTypeIsUnsigned(self);
|
||||
@ -107,7 +114,7 @@ public:
|
||||
MlirType t = mlirIndexTypeGet(context->get());
|
||||
return PyIndexType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a index type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a index type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -118,7 +125,7 @@ public:
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
|
||||
"Returns the width of the floating-point type");
|
||||
}
|
||||
@ -141,7 +148,7 @@ public:
|
||||
MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
|
||||
return PyFloat4E2M1FNType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float4_e2m1fn type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -162,7 +169,7 @@ public:
|
||||
MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
|
||||
return PyFloat6E2M3FNType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float6_e2m3fn type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -183,7 +190,7 @@ public:
|
||||
MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
|
||||
return PyFloat6E3M2FNType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float6_e3m2fn type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -204,7 +211,7 @@ public:
|
||||
MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
|
||||
return PyFloat8E4M3FNType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e4m3fn type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -224,7 +231,7 @@ public:
|
||||
MlirType t = mlirFloat8E5M2TypeGet(context->get());
|
||||
return PyFloat8E5M2Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e5m2 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -244,7 +251,7 @@ public:
|
||||
MlirType t = mlirFloat8E4M3TypeGet(context->get());
|
||||
return PyFloat8E4M3Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e4m3 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -265,7 +272,8 @@ public:
|
||||
MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
|
||||
return PyFloat8E4M3FNUZType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e4m3fnuz type.");
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Create a float8_e4m3fnuz type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -286,7 +294,8 @@ public:
|
||||
MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
|
||||
return PyFloat8E4M3B11FNUZType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type.");
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Create a float8_e4m3b11fnuz type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -307,7 +316,8 @@ public:
|
||||
MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
|
||||
return PyFloat8E5M2FNUZType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e5m2fnuz type.");
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Create a float8_e5m2fnuz type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -327,7 +337,7 @@ public:
|
||||
MlirType t = mlirFloat8E3M4TypeGet(context->get());
|
||||
return PyFloat8E3M4Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e3m4 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -348,7 +358,8 @@ public:
|
||||
MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
|
||||
return PyFloat8E8M0FNUType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a float8_e8m0fnu type.");
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Create a float8_e8m0fnu type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -368,7 +379,7 @@ public:
|
||||
MlirType t = mlirBF16TypeGet(context->get());
|
||||
return PyBF16Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a bf16 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a bf16 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -388,7 +399,7 @@ public:
|
||||
MlirType t = mlirF16TypeGet(context->get());
|
||||
return PyF16Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a f16 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a f16 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -408,7 +419,7 @@ public:
|
||||
MlirType t = mlirTF32TypeGet(context->get());
|
||||
return PyTF32Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a tf32 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a tf32 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -428,7 +439,7 @@ public:
|
||||
MlirType t = mlirF32TypeGet(context->get());
|
||||
return PyF32Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a f32 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a f32 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -448,7 +459,7 @@ public:
|
||||
MlirType t = mlirF64TypeGet(context->get());
|
||||
return PyF64Type(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a f64 type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a f64 type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -468,7 +479,7 @@ public:
|
||||
MlirType t = mlirNoneTypeGet(context->get());
|
||||
return PyNoneType(context->getRef(), t);
|
||||
},
|
||||
py::arg("context") = py::none(), "Create a none type.");
|
||||
nb::arg("context").none() = nb::none(), "Create a none type.");
|
||||
}
|
||||
};
|
||||
|
||||
@ -490,14 +501,15 @@ public:
|
||||
MlirType t = mlirComplexTypeGet(elementType);
|
||||
return PyComplexType(elementType.getContext(), t);
|
||||
}
|
||||
throw py::value_error(
|
||||
throw nb::value_error(
|
||||
(Twine("invalid '") +
|
||||
py::repr(py::cast(elementType)).cast<std::string>() +
|
||||
nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
|
||||
"' and expected floating point or integer type.")
|
||||
.str());
|
||||
.str()
|
||||
.c_str());
|
||||
},
|
||||
"Create a complex type");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"element_type",
|
||||
[](PyComplexType &self) { return mlirComplexTypeGetElementType(self); },
|
||||
"Returns element type.");
|
||||
@ -508,22 +520,22 @@ public:
|
||||
|
||||
// Shaped Type Interface - ShapedType
|
||||
void mlir::PyShapedType::bindDerived(ClassTy &c) {
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"element_type",
|
||||
[](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
|
||||
"Returns the element type of the shaped type.");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"has_rank",
|
||||
[](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
|
||||
"Returns whether the given shaped type is ranked.");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"rank",
|
||||
[](PyShapedType &self) {
|
||||
self.requireHasRank();
|
||||
return mlirShapedTypeGetRank(self);
|
||||
},
|
||||
"Returns the rank of the given ranked shaped type.");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"has_static_shape",
|
||||
[](PyShapedType &self) -> bool {
|
||||
return mlirShapedTypeHasStaticShape(self);
|
||||
@ -535,7 +547,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
|
||||
self.requireHasRank();
|
||||
return mlirShapedTypeIsDynamicDim(self, dim);
|
||||
},
|
||||
py::arg("dim"),
|
||||
nb::arg("dim"),
|
||||
"Returns whether the dim-th dimension of the given shaped type is "
|
||||
"dynamic.");
|
||||
c.def(
|
||||
@ -544,12 +556,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
|
||||
self.requireHasRank();
|
||||
return mlirShapedTypeGetDimSize(self, dim);
|
||||
},
|
||||
py::arg("dim"),
|
||||
nb::arg("dim"),
|
||||
"Returns the dim-th dimension of the given ranked shaped type.");
|
||||
c.def_static(
|
||||
"is_dynamic_size",
|
||||
[](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
|
||||
py::arg("dim_size"),
|
||||
nb::arg("dim_size"),
|
||||
"Returns whether the given dimension size indicates a dynamic "
|
||||
"dimension.");
|
||||
c.def(
|
||||
@ -558,10 +570,10 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
|
||||
self.requireHasRank();
|
||||
return mlirShapedTypeIsDynamicStrideOrOffset(val);
|
||||
},
|
||||
py::arg("dim_size"),
|
||||
nb::arg("dim_size"),
|
||||
"Returns whether the given value is used as a placeholder for dynamic "
|
||||
"strides and offsets in shaped types.");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"shape",
|
||||
[](PyShapedType &self) {
|
||||
self.requireHasRank();
|
||||
@ -587,7 +599,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) {
|
||||
|
||||
void mlir::PyShapedType::requireHasRank() {
|
||||
if (!mlirShapedTypeHasRank(*this)) {
|
||||
throw py::value_error(
|
||||
throw nb::value_error(
|
||||
"calling this method requires that the type has a rank.");
|
||||
}
|
||||
}
|
||||
@ -607,15 +619,15 @@ public:
|
||||
using PyConcreteType::PyConcreteType;
|
||||
|
||||
static void bindDerived(ClassTy &c) {
|
||||
c.def_static("get", &PyVectorType::get, py::arg("shape"),
|
||||
py::arg("element_type"), py::kw_only(),
|
||||
py::arg("scalable") = py::none(),
|
||||
py::arg("scalable_dims") = py::none(),
|
||||
py::arg("loc") = py::none(), "Create a vector type")
|
||||
.def_property_readonly(
|
||||
c.def_static("get", &PyVectorType::get, nb::arg("shape"),
|
||||
nb::arg("element_type"), nb::kw_only(),
|
||||
nb::arg("scalable").none() = nb::none(),
|
||||
nb::arg("scalable_dims").none() = nb::none(),
|
||||
nb::arg("loc").none() = nb::none(), "Create a vector type")
|
||||
.def_prop_ro(
|
||||
"scalable",
|
||||
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
|
||||
.def_property_readonly("scalable_dims", [](MlirType self) {
|
||||
.def_prop_ro("scalable_dims", [](MlirType self) {
|
||||
std::vector<bool> scalableDims;
|
||||
size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
|
||||
scalableDims.reserve(rank);
|
||||
@ -627,11 +639,11 @@ public:
|
||||
|
||||
private:
|
||||
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
|
||||
std::optional<py::list> scalable,
|
||||
std::optional<nb::list> scalable,
|
||||
std::optional<std::vector<int64_t>> scalableDims,
|
||||
DefaultingPyLocation loc) {
|
||||
if (scalable && scalableDims) {
|
||||
throw py::value_error("'scalable' and 'scalable_dims' kwargs "
|
||||
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
|
||||
"are mutually exclusive.");
|
||||
}
|
||||
|
||||
@ -639,10 +651,10 @@ private:
|
||||
MlirType type;
|
||||
if (scalable) {
|
||||
if (scalable->size() != shape.size())
|
||||
throw py::value_error("Expected len(scalable) == len(shape).");
|
||||
throw nb::value_error("Expected len(scalable) == len(shape).");
|
||||
|
||||
SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
|
||||
*scalable, [](const py::handle &h) { return h.cast<bool>(); }));
|
||||
*scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
|
||||
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
|
||||
scalableDimFlags.data(),
|
||||
elementType);
|
||||
@ -650,7 +662,7 @@ private:
|
||||
SmallVector<bool> scalableDimFlags(shape.size(), false);
|
||||
for (int64_t dim : *scalableDims) {
|
||||
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
|
||||
throw py::value_error("Scalable dimension index out of bounds.");
|
||||
throw nb::value_error("Scalable dimension index out of bounds.");
|
||||
scalableDimFlags[dim] = true;
|
||||
}
|
||||
type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
|
||||
@ -689,17 +701,17 @@ public:
|
||||
throw MLIRError("Invalid type", errors.take());
|
||||
return PyRankedTensorType(elementType.getContext(), t);
|
||||
},
|
||||
py::arg("shape"), py::arg("element_type"),
|
||||
py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
|
||||
"Create a ranked tensor type");
|
||||
c.def_property_readonly(
|
||||
"encoding",
|
||||
[](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
|
||||
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
|
||||
if (mlirAttributeIsNull(encoding))
|
||||
return std::nullopt;
|
||||
return encoding;
|
||||
});
|
||||
nb::arg("shape"), nb::arg("element_type"),
|
||||
nb::arg("encoding").none() = nb::none(),
|
||||
nb::arg("loc").none() = nb::none(), "Create a ranked tensor type");
|
||||
c.def_prop_ro("encoding",
|
||||
[](PyRankedTensorType &self) -> std::optional<MlirAttribute> {
|
||||
MlirAttribute encoding =
|
||||
mlirRankedTensorTypeGetEncoding(self.get());
|
||||
if (mlirAttributeIsNull(encoding))
|
||||
return std::nullopt;
|
||||
return encoding;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@ -723,7 +735,7 @@ public:
|
||||
throw MLIRError("Invalid type", errors.take());
|
||||
return PyUnrankedTensorType(elementType.getContext(), t);
|
||||
},
|
||||
py::arg("element_type"), py::arg("loc") = py::none(),
|
||||
nb::arg("element_type"), nb::arg("loc").none() = nb::none(),
|
||||
"Create a unranked tensor type");
|
||||
}
|
||||
};
|
||||
@ -754,10 +766,11 @@ public:
|
||||
throw MLIRError("Invalid type", errors.take());
|
||||
return PyMemRefType(elementType.getContext(), t);
|
||||
},
|
||||
py::arg("shape"), py::arg("element_type"),
|
||||
py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
|
||||
py::arg("loc") = py::none(), "Create a memref type")
|
||||
.def_property_readonly(
|
||||
nb::arg("shape"), nb::arg("element_type"),
|
||||
nb::arg("layout").none() = nb::none(),
|
||||
nb::arg("memory_space").none() = nb::none(),
|
||||
nb::arg("loc").none() = nb::none(), "Create a memref type")
|
||||
.def_prop_ro(
|
||||
"layout",
|
||||
[](PyMemRefType &self) -> MlirAttribute {
|
||||
return mlirMemRefTypeGetLayout(self);
|
||||
@ -775,14 +788,14 @@ public:
|
||||
return {strides, offset};
|
||||
},
|
||||
"The strides and offset of the MemRef type.")
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"affine_map",
|
||||
[](PyMemRefType &self) -> PyAffineMap {
|
||||
MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
|
||||
return PyAffineMap(self.getContext(), map);
|
||||
},
|
||||
"The layout of the MemRef type as an affine map.")
|
||||
.def_property_readonly(
|
||||
.def_prop_ro(
|
||||
"memory_space",
|
||||
[](PyMemRefType &self) -> std::optional<MlirAttribute> {
|
||||
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
|
||||
@ -820,9 +833,9 @@ public:
|
||||
throw MLIRError("Invalid type", errors.take());
|
||||
return PyUnrankedMemRefType(elementType.getContext(), t);
|
||||
},
|
||||
py::arg("element_type"), py::arg("memory_space"),
|
||||
py::arg("loc") = py::none(), "Create a unranked memref type")
|
||||
.def_property_readonly(
|
||||
nb::arg("element_type"), nb::arg("memory_space").none(),
|
||||
nb::arg("loc").none() = nb::none(), "Create a unranked memref type")
|
||||
.def_prop_ro(
|
||||
"memory_space",
|
||||
[](PyUnrankedMemRefType &self) -> std::optional<MlirAttribute> {
|
||||
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
|
||||
@ -851,15 +864,15 @@ public:
|
||||
elements.data());
|
||||
return PyTupleType(context->getRef(), t);
|
||||
},
|
||||
py::arg("elements"), py::arg("context") = py::none(),
|
||||
nb::arg("elements"), nb::arg("context").none() = nb::none(),
|
||||
"Create a tuple type");
|
||||
c.def(
|
||||
"get_type",
|
||||
[](PyTupleType &self, intptr_t pos) {
|
||||
return mlirTupleTypeGetType(self, pos);
|
||||
},
|
||||
py::arg("pos"), "Returns the pos-th type in the tuple type.");
|
||||
c.def_property_readonly(
|
||||
nb::arg("pos"), "Returns the pos-th type in the tuple type.");
|
||||
c.def_prop_ro(
|
||||
"num_types",
|
||||
[](PyTupleType &self) -> intptr_t {
|
||||
return mlirTupleTypeGetNumTypes(self);
|
||||
@ -887,13 +900,14 @@ public:
|
||||
results.size(), results.data());
|
||||
return PyFunctionType(context->getRef(), t);
|
||||
},
|
||||
py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
|
||||
nb::arg("inputs"), nb::arg("results"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Gets a FunctionType from a list of input and result types");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"inputs",
|
||||
[](PyFunctionType &self) {
|
||||
MlirType t = self;
|
||||
py::list types;
|
||||
nb::list types;
|
||||
for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
|
||||
++i) {
|
||||
types.append(mlirFunctionTypeGetInput(t, i));
|
||||
@ -901,10 +915,10 @@ public:
|
||||
return types;
|
||||
},
|
||||
"Returns the list of input types in the FunctionType.");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"results",
|
||||
[](PyFunctionType &self) {
|
||||
py::list types;
|
||||
nb::list types;
|
||||
for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
|
||||
++i) {
|
||||
types.append(mlirFunctionTypeGetResult(self, i));
|
||||
@ -938,21 +952,21 @@ public:
|
||||
toMlirStringRef(typeData));
|
||||
return PyOpaqueType(context->getRef(), type);
|
||||
},
|
||||
py::arg("dialect_namespace"), py::arg("buffer"),
|
||||
py::arg("context") = py::none(),
|
||||
nb::arg("dialect_namespace"), nb::arg("buffer"),
|
||||
nb::arg("context").none() = nb::none(),
|
||||
"Create an unregistered (opaque) dialect type.");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"dialect_namespace",
|
||||
[](PyOpaqueType &self) {
|
||||
MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
|
||||
return py::str(stringRef.data, stringRef.length);
|
||||
return nb::str(stringRef.data, stringRef.length);
|
||||
},
|
||||
"Returns the dialect namespace for the Opaque type as a string.");
|
||||
c.def_property_readonly(
|
||||
c.def_prop_ro(
|
||||
"data",
|
||||
[](PyOpaqueType &self) {
|
||||
MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
|
||||
return py::str(stringRef.data, stringRef.length);
|
||||
return nb::str(stringRef.data, stringRef.length);
|
||||
},
|
||||
"Returns the data for the Opaque type as a string.");
|
||||
}
|
||||
@ -960,7 +974,7 @@ public:
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::python::populateIRTypes(py::module &m) {
|
||||
void mlir::python::populateIRTypes(nb::module_ &m) {
|
||||
PyIntegerType::bind(m);
|
||||
PyFloatType::bind(m);
|
||||
PyIndexType::bind(m);
|
||||
|
@ -6,29 +6,31 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PybindUtils.h"
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include "Globals.h"
|
||||
#include "IRModule.h"
|
||||
#include "NanobindUtils.h"
|
||||
#include "Pass.h"
|
||||
#include "Rewrite.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlir;
|
||||
using namespace py::literals;
|
||||
using namespace nb::literals;
|
||||
using namespace mlir::python;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Module initialization.
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(_mlir, m) {
|
||||
NB_MODULE(_mlir, m) {
|
||||
m.doc() = "MLIR Python Native Extension";
|
||||
|
||||
py::class_<PyGlobals>(m, "_Globals", py::module_local())
|
||||
.def_property("dialect_search_modules",
|
||||
&PyGlobals::getDialectSearchPrefixes,
|
||||
&PyGlobals::setDialectSearchPrefixes)
|
||||
nb::class_<PyGlobals>(m, "_Globals")
|
||||
.def_prop_rw("dialect_search_modules",
|
||||
&PyGlobals::getDialectSearchPrefixes,
|
||||
&PyGlobals::setDialectSearchPrefixes)
|
||||
.def(
|
||||
"append_dialect_search_prefix",
|
||||
[](PyGlobals &self, std::string moduleName) {
|
||||
@ -45,22 +47,21 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
"dialect_namespace"_a, "dialect_class"_a,
|
||||
"Testing hook for directly registering a dialect")
|
||||
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
|
||||
"operation_name"_a, "operation_class"_a, py::kw_only(),
|
||||
"operation_name"_a, "operation_class"_a, nb::kw_only(),
|
||||
"replace"_a = false,
|
||||
"Testing hook for directly registering an operation");
|
||||
|
||||
// Aside from making the globals accessible to python, having python manage
|
||||
// it is necessary to make sure it is destroyed (and releases its python
|
||||
// resources) properly.
|
||||
m.attr("globals") =
|
||||
py::cast(new PyGlobals, py::return_value_policy::take_ownership);
|
||||
m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership);
|
||||
|
||||
// Registration decorators.
|
||||
m.def(
|
||||
"register_dialect",
|
||||
[](py::type pyClass) {
|
||||
[](nb::type_object pyClass) {
|
||||
std::string dialectNamespace =
|
||||
pyClass.attr("DIALECT_NAMESPACE").cast<std::string>();
|
||||
nanobind::cast<std::string>(pyClass.attr("DIALECT_NAMESPACE"));
|
||||
PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass);
|
||||
return pyClass;
|
||||
},
|
||||
@ -68,45 +69,46 @@ PYBIND11_MODULE(_mlir, m) {
|
||||
"Class decorator for registering a custom Dialect wrapper");
|
||||
m.def(
|
||||
"register_operation",
|
||||
[](const py::type &dialectClass, bool replace) -> py::cpp_function {
|
||||
return py::cpp_function(
|
||||
[dialectClass, replace](py::type opClass) -> py::type {
|
||||
[](const nb::type_object &dialectClass, bool replace) -> nb::object {
|
||||
return nb::cpp_function(
|
||||
[dialectClass,
|
||||
replace](nb::type_object opClass) -> nb::type_object {
|
||||
std::string operationName =
|
||||
opClass.attr("OPERATION_NAME").cast<std::string>();
|
||||
nanobind::cast<std::string>(opClass.attr("OPERATION_NAME"));
|
||||
PyGlobals::get().registerOperationImpl(operationName, opClass,
|
||||
replace);
|
||||
|
||||
// Dict-stuff the new opClass by name onto the dialect class.
|
||||
py::object opClassName = opClass.attr("__name__");
|
||||
nb::object opClassName = opClass.attr("__name__");
|
||||
dialectClass.attr(opClassName) = opClass;
|
||||
return opClass;
|
||||
});
|
||||
},
|
||||
"dialect_class"_a, py::kw_only(), "replace"_a = false,
|
||||
"dialect_class"_a, nb::kw_only(), "replace"_a = false,
|
||||
"Produce a class decorator for registering an Operation class as part of "
|
||||
"a dialect");
|
||||
m.def(
|
||||
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
|
||||
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
|
||||
return py::cpp_function([mlirTypeID,
|
||||
replace](py::object typeCaster) -> py::object {
|
||||
[](MlirTypeID mlirTypeID, bool replace) -> nb::object {
|
||||
return nb::cpp_function([mlirTypeID, replace](
|
||||
nb::callable typeCaster) -> nb::object {
|
||||
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
|
||||
return typeCaster;
|
||||
});
|
||||
},
|
||||
"typeid"_a, py::kw_only(), "replace"_a = false,
|
||||
"typeid"_a, nb::kw_only(), "replace"_a = false,
|
||||
"Register a type caster for casting MLIR types to custom user types.");
|
||||
m.def(
|
||||
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
|
||||
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
|
||||
return py::cpp_function(
|
||||
[mlirTypeID, replace](py::object valueCaster) -> py::object {
|
||||
[](MlirTypeID mlirTypeID, bool replace) -> nb::object {
|
||||
return nb::cpp_function(
|
||||
[mlirTypeID, replace](nb::callable valueCaster) -> nb::object {
|
||||
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
|
||||
replace);
|
||||
return valueCaster;
|
||||
});
|
||||
},
|
||||
"typeid"_a, py::kw_only(), "replace"_a = false,
|
||||
"typeid"_a, nb::kw_only(), "replace"_a = false,
|
||||
"Register a value caster for casting MLIR values to custom user values.");
|
||||
|
||||
// Define and populate IR submodule.
|
||||
|
@ -1,4 +1,5 @@
|
||||
//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
|
||||
//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++
|
||||
//-*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
@ -9,13 +10,21 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
|
||||
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include "mlir-c/Support.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/DataTypes.h"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
template <>
|
||||
struct std::iterator_traits<nanobind::detail::fast_iterator> {
|
||||
using value_type = nanobind::handle;
|
||||
using reference = const value_type;
|
||||
using pointer = void;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
};
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
@ -54,14 +63,14 @@ private:
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
||||
namespace pybind11 {
|
||||
namespace nanobind {
|
||||
namespace detail {
|
||||
|
||||
template <typename DefaultingTy>
|
||||
struct MlirDefaultingCaster {
|
||||
PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
|
||||
NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription));
|
||||
|
||||
bool load(pybind11::handle src, bool) {
|
||||
bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) {
|
||||
if (src.is_none()) {
|
||||
// Note that we do want an exception to propagate from here as it will be
|
||||
// the most informative.
|
||||
@ -76,20 +85,20 @@ struct MlirDefaultingCaster {
|
||||
// code to produce nice error messages (other than "Cannot cast...").
|
||||
try {
|
||||
value = DefaultingTy{
|
||||
pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
|
||||
nanobind::cast<typename DefaultingTy::ReferrentTy &>(src)};
|
||||
return true;
|
||||
} catch (std::exception &) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static handle cast(DefaultingTy src, return_value_policy policy,
|
||||
handle parent) {
|
||||
return pybind11::cast(src, policy);
|
||||
static handle from_cpp(DefaultingTy src, rv_policy policy,
|
||||
cleanup_list *cleanup) noexcept {
|
||||
return nanobind::cast(src, policy);
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
} // namespace nanobind
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// Conversion utilities.
|
||||
@ -100,7 +109,7 @@ namespace mlir {
|
||||
/// Accumulates into a python string from a method that accepts an
|
||||
/// MlirStringCallback.
|
||||
struct PyPrintAccumulator {
|
||||
pybind11::list parts;
|
||||
nanobind::list parts;
|
||||
|
||||
void *getUserData() { return this; }
|
||||
|
||||
@ -108,15 +117,15 @@ struct PyPrintAccumulator {
|
||||
return [](MlirStringRef part, void *userData) {
|
||||
PyPrintAccumulator *printAccum =
|
||||
static_cast<PyPrintAccumulator *>(userData);
|
||||
pybind11::str pyPart(part.data,
|
||||
nanobind::str pyPart(part.data,
|
||||
part.length); // Decodes as UTF-8 by default.
|
||||
printAccum->parts.append(std::move(pyPart));
|
||||
};
|
||||
}
|
||||
|
||||
pybind11::str join() {
|
||||
pybind11::str delim("", 0);
|
||||
return delim.attr("join")(parts);
|
||||
nanobind::str join() {
|
||||
nanobind::str delim("", 0);
|
||||
return nanobind::cast<nanobind::str>(delim.attr("join")(parts));
|
||||
}
|
||||
};
|
||||
|
||||
@ -124,21 +133,21 @@ struct PyPrintAccumulator {
|
||||
/// or binary.
|
||||
class PyFileAccumulator {
|
||||
public:
|
||||
PyFileAccumulator(const pybind11::object &fileObject, bool binary)
|
||||
PyFileAccumulator(const nanobind::object &fileObject, bool binary)
|
||||
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
|
||||
|
||||
void *getUserData() { return this; }
|
||||
|
||||
MlirStringCallback getCallback() {
|
||||
return [](MlirStringRef part, void *userData) {
|
||||
pybind11::gil_scoped_acquire acquire;
|
||||
nanobind::gil_scoped_acquire acquire;
|
||||
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
|
||||
if (accum->binary) {
|
||||
// Note: Still has to copy and not avoidable with this API.
|
||||
pybind11::bytes pyBytes(part.data, part.length);
|
||||
nanobind::bytes pyBytes(part.data, part.length);
|
||||
accum->pyWriteFunction(pyBytes);
|
||||
} else {
|
||||
pybind11::str pyStr(part.data,
|
||||
nanobind::str pyStr(part.data,
|
||||
part.length); // Decodes as UTF-8 by default.
|
||||
accum->pyWriteFunction(pyStr);
|
||||
}
|
||||
@ -146,7 +155,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
pybind11::object pyWriteFunction;
|
||||
nanobind::object pyWriteFunction;
|
||||
bool binary;
|
||||
};
|
||||
|
||||
@ -163,17 +172,17 @@ struct PySinglePartStringAccumulator {
|
||||
assert(!accum->invoked &&
|
||||
"PySinglePartStringAccumulator called back multiple times");
|
||||
accum->invoked = true;
|
||||
accum->value = pybind11::str(part.data, part.length);
|
||||
accum->value = nanobind::str(part.data, part.length);
|
||||
};
|
||||
}
|
||||
|
||||
pybind11::str takeValue() {
|
||||
nanobind::str takeValue() {
|
||||
assert(invoked && "PySinglePartStringAccumulator not called back");
|
||||
return std::move(value);
|
||||
}
|
||||
|
||||
private:
|
||||
pybind11::str value;
|
||||
nanobind::str value;
|
||||
bool invoked = false;
|
||||
};
|
||||
|
||||
@ -208,7 +217,7 @@ private:
|
||||
template <typename Derived, typename ElementTy>
|
||||
class Sliceable {
|
||||
protected:
|
||||
using ClassTy = pybind11::class_<Derived>;
|
||||
using ClassTy = nanobind::class_<Derived>;
|
||||
|
||||
/// Transforms `index` into a legal value to access the underlying sequence.
|
||||
/// Returns <0 on failure.
|
||||
@ -237,7 +246,7 @@ protected:
|
||||
/// Returns the element at the given slice index. Supports negative indices
|
||||
/// by taking elements in inverse order. Returns a nullptr object if out
|
||||
/// of bounds.
|
||||
pybind11::object getItem(intptr_t index) {
|
||||
nanobind::object getItem(intptr_t index) {
|
||||
// Negative indices mean we count from the end.
|
||||
index = wrapIndex(index);
|
||||
if (index < 0) {
|
||||
@ -250,20 +259,20 @@ protected:
|
||||
->getRawElement(linearizeIndex(index))
|
||||
.maybeDownCast();
|
||||
else
|
||||
return pybind11::cast(
|
||||
return nanobind::cast(
|
||||
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
|
||||
}
|
||||
|
||||
/// Returns a new instance of the pseudo-container restricted to the given
|
||||
/// slice. Returns a nullptr object on failure.
|
||||
pybind11::object getItemSlice(PyObject *slice) {
|
||||
nanobind::object getItemSlice(PyObject *slice) {
|
||||
ssize_t start, stop, extraStep, sliceLength;
|
||||
if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
|
||||
&sliceLength) != 0) {
|
||||
PyErr_SetString(PyExc_IndexError, "index out of range");
|
||||
return {};
|
||||
}
|
||||
return pybind11::cast(static_cast<Derived *>(this)->slice(
|
||||
return nanobind::cast(static_cast<Derived *>(this)->slice(
|
||||
startIndex + start * step, sliceLength, step * extraStep));
|
||||
}
|
||||
|
||||
@ -279,7 +288,7 @@ public:
|
||||
// Negative indices mean we count from the end.
|
||||
index = wrapIndex(index);
|
||||
if (index < 0) {
|
||||
throw pybind11::index_error("index out of range");
|
||||
throw nanobind::index_error("index out of range");
|
||||
}
|
||||
|
||||
return static_cast<Derived *>(this)->getRawElement(linearizeIndex(index));
|
||||
@ -304,39 +313,38 @@ public:
|
||||
}
|
||||
|
||||
/// Binds the indexing and length methods in the Python class.
|
||||
static void bind(pybind11::module &m) {
|
||||
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
|
||||
pybind11::module_local())
|
||||
static void bind(nanobind::module_ &m) {
|
||||
auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
|
||||
.def("__add__", &Sliceable::dunderAdd);
|
||||
Derived::bindDerived(clazz);
|
||||
|
||||
// Manually implement the sequence protocol via the C API. We do this
|
||||
// because it is approx 4x faster than via pybind11, largely because that
|
||||
// because it is approx 4x faster than via nanobind, largely because that
|
||||
// formulation requires a C++ exception to be thrown to detect end of
|
||||
// sequence.
|
||||
// Since we are in a C-context, any C++ exception that happens here
|
||||
// will terminate the program. There is nothing in this implementation
|
||||
// that should throw in a non-terminal way, so we forgo further
|
||||
// exception marshalling.
|
||||
// See: https://github.com/pybind/pybind11/issues/2842
|
||||
// See: https://github.com/pybind/nanobind/issues/2842
|
||||
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
|
||||
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
|
||||
"must be heap type");
|
||||
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
|
||||
auto self = pybind11::cast<Derived *>(rawSelf);
|
||||
auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
|
||||
return self->length;
|
||||
};
|
||||
// sq_item is called as part of the sequence protocol for iteration,
|
||||
// list construction, etc.
|
||||
heap_type->as_sequence.sq_item =
|
||||
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
|
||||
auto self = pybind11::cast<Derived *>(rawSelf);
|
||||
auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
|
||||
return self->getItem(index).release().ptr();
|
||||
};
|
||||
// mp_subscript is used for both slices and integer lookups.
|
||||
heap_type->as_mapping.mp_subscript =
|
||||
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
|
||||
auto self = pybind11::cast<Derived *>(rawSelf);
|
||||
auto self = nanobind::cast<Derived *>(nanobind::handle(rawSelf));
|
||||
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
|
||||
if (!PyErr_Occurred()) {
|
||||
// Integer indexing.
|
@ -8,12 +8,16 @@
|
||||
|
||||
#include "Pass.h"
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/string.h>
|
||||
|
||||
#include "IRModule.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Pass.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace py::literals;
|
||||
namespace nb = nanobind;
|
||||
using namespace nb::literals;
|
||||
using namespace mlir;
|
||||
using namespace mlir::python;
|
||||
|
||||
@ -34,16 +38,15 @@ public:
|
||||
MlirPassManager get() { return passManager; }
|
||||
|
||||
void release() { passManager.ptr = nullptr; }
|
||||
pybind11::object getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(
|
||||
mlirPythonPassManagerToCapsule(get()));
|
||||
nb::object getCapsule() {
|
||||
return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
|
||||
}
|
||||
|
||||
static pybind11::object createFromCapsule(pybind11::object capsule) {
|
||||
static nb::object createFromCapsule(nb::object capsule) {
|
||||
MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
|
||||
if (mlirPassManagerIsNull(rawPm))
|
||||
throw py::error_already_set();
|
||||
return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
|
||||
throw nb::python_error();
|
||||
return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -53,22 +56,23 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Create the `mlir.passmanager` here.
|
||||
void mlir::python::populatePassManagerSubmodule(py::module &m) {
|
||||
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of the top-level PassManager
|
||||
//----------------------------------------------------------------------------
|
||||
py::class_<PyPassManager>(m, "PassManager", py::module_local())
|
||||
.def(py::init<>([](const std::string &anchorOp,
|
||||
DefaultingPyMlirContext context) {
|
||||
MlirPassManager passManager = mlirPassManagerCreateOnOperation(
|
||||
context->get(),
|
||||
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
|
||||
return new PyPassManager(passManager);
|
||||
}),
|
||||
"anchor_op"_a = py::str("any"), "context"_a = py::none(),
|
||||
"Create a new PassManager for the current (or provided) Context.")
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyPassManager::getCapsule)
|
||||
nb::class_<PyPassManager>(m, "PassManager")
|
||||
.def(
|
||||
"__init__",
|
||||
[](PyPassManager &self, const std::string &anchorOp,
|
||||
DefaultingPyMlirContext context) {
|
||||
MlirPassManager passManager = mlirPassManagerCreateOnOperation(
|
||||
context->get(),
|
||||
mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
|
||||
new (&self) PyPassManager(passManager);
|
||||
},
|
||||
"anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
|
||||
"Create a new PassManager for the current (or provided) Context.")
|
||||
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
|
||||
.def("_testing_release", &PyPassManager::release,
|
||||
"Releases (leaks) the backing pass manager (testing)")
|
||||
@ -101,9 +105,9 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
|
||||
"print_before_all"_a = false, "print_after_all"_a = true,
|
||||
"print_module_scope"_a = false, "print_after_change"_a = false,
|
||||
"print_after_failure"_a = false,
|
||||
"large_elements_limit"_a = py::none(), "enable_debug_info"_a = false,
|
||||
"print_generic_op_form"_a = false,
|
||||
"tree_printing_dir_path"_a = py::none(),
|
||||
"large_elements_limit"_a.none() = nb::none(),
|
||||
"enable_debug_info"_a = false, "print_generic_op_form"_a = false,
|
||||
"tree_printing_dir_path"_a.none() = nb::none(),
|
||||
"Enable IR printing, default as mlir-print-ir-after-all.")
|
||||
.def(
|
||||
"enable_verifier",
|
||||
@ -121,10 +125,10 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
|
||||
mlirStringRefCreate(pipeline.data(), pipeline.size()),
|
||||
errorMsg.getCallback(), errorMsg.getUserData());
|
||||
if (mlirLogicalResultIsFailure(status))
|
||||
throw py::value_error(std::string(errorMsg.join()));
|
||||
throw nb::value_error(errorMsg.join().c_str());
|
||||
return new PyPassManager(passManager);
|
||||
},
|
||||
"pipeline"_a, "context"_a = py::none(),
|
||||
"pipeline"_a, "context"_a.none() = nb::none(),
|
||||
"Parse a textual pass-pipeline and return a top-level PassManager "
|
||||
"that can be applied on a Module. Throw a ValueError if the pipeline "
|
||||
"can't be parsed")
|
||||
@ -137,7 +141,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) {
|
||||
mlirStringRefCreate(pipeline.data(), pipeline.size()),
|
||||
errorMsg.getCallback(), errorMsg.getUserData());
|
||||
if (mlirLogicalResultIsFailure(status))
|
||||
throw py::value_error(std::string(errorMsg.join()));
|
||||
throw nb::value_error(errorMsg.join().c_str());
|
||||
},
|
||||
"pipeline"_a,
|
||||
"Add textual pipeline elements to the pass manager. Throws a "
|
||||
|
@ -9,12 +9,12 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_PASS_H
|
||||
#define MLIR_BINDINGS_PYTHON_PASS_H
|
||||
|
||||
#include "PybindUtils.h"
|
||||
#include "NanobindUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
void populatePassManagerSubmodule(pybind11::module &m);
|
||||
void populatePassManagerSubmodule(nanobind::module_ &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
@ -8,14 +8,16 @@
|
||||
|
||||
#include "Rewrite.h"
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
|
||||
#include "IRModule.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/Rewrite.h"
|
||||
#include "mlir/Config/mlir-config.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace nb = nanobind;
|
||||
using namespace mlir;
|
||||
using namespace py::literals;
|
||||
using namespace nb::literals;
|
||||
using namespace mlir::python;
|
||||
|
||||
namespace {
|
||||
@ -54,18 +56,17 @@ public:
|
||||
}
|
||||
MlirFrozenRewritePatternSet get() { return set; }
|
||||
|
||||
pybind11::object getCapsule() {
|
||||
return py::reinterpret_steal<py::object>(
|
||||
nb::object getCapsule() {
|
||||
return nb::steal<nb::object>(
|
||||
mlirPythonFrozenRewritePatternSetToCapsule(get()));
|
||||
}
|
||||
|
||||
static pybind11::object createFromCapsule(pybind11::object capsule) {
|
||||
static nb::object createFromCapsule(nb::object capsule) {
|
||||
MlirFrozenRewritePatternSet rawPm =
|
||||
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
|
||||
if (rawPm.ptr == nullptr)
|
||||
throw py::error_already_set();
|
||||
return py::cast(PyFrozenRewritePatternSet(rawPm),
|
||||
py::return_value_policy::move);
|
||||
throw nb::python_error();
|
||||
return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -75,25 +76,27 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Create the `mlir.rewrite` here.
|
||||
void mlir::python::populateRewriteSubmodule(py::module &m) {
|
||||
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
|
||||
//----------------------------------------------------------------------------
|
||||
// Mapping of the top-level PassManager
|
||||
//----------------------------------------------------------------------------
|
||||
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
||||
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
|
||||
.def(py::init<>([](MlirModule module) {
|
||||
return mlirPDLPatternModuleFromModule(module);
|
||||
}),
|
||||
"module"_a, "Create a PDL module from the given module.")
|
||||
nb::class_<PyPDLPatternModule>(m, "PDLModule")
|
||||
.def(
|
||||
"__init__",
|
||||
[](PyPDLPatternModule &self, MlirModule module) {
|
||||
new (&self)
|
||||
PyPDLPatternModule(mlirPDLPatternModuleFromModule(module));
|
||||
},
|
||||
"module"_a, "Create a PDL module from the given module.")
|
||||
.def("freeze", [](PyPDLPatternModule &self) {
|
||||
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
|
||||
mlirRewritePatternSetFromPDLPatternModule(self.get())));
|
||||
});
|
||||
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
|
||||
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
|
||||
py::module_local())
|
||||
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyFrozenRewritePatternSet::getCapsule)
|
||||
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
||||
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
|
||||
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
|
||||
&PyFrozenRewritePatternSet::getCapsule)
|
||||
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
|
||||
&PyFrozenRewritePatternSet::createFromCapsule);
|
||||
m.def(
|
||||
@ -102,7 +105,7 @@ void mlir::python::populateRewriteSubmodule(py::module &m) {
|
||||
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
|
||||
if (mlirLogicalResultIsFailure(status))
|
||||
// FIXME: Not sure this is the right error to throw here.
|
||||
throw py::value_error("pattern application failed to converge");
|
||||
throw nb::value_error("pattern application failed to converge");
|
||||
},
|
||||
"module"_a, "set"_a,
|
||||
"Applys the given patterns to the given module greedily while folding "
|
||||
|
@ -9,12 +9,12 @@
|
||||
#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H
|
||||
#define MLIR_BINDINGS_PYTHON_REWRITE_H
|
||||
|
||||
#include "PybindUtils.h"
|
||||
#include "NanobindUtils.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace python {
|
||||
|
||||
void populateRewriteSubmodule(pybind11::module &m);
|
||||
void populateRewriteSubmodule(nanobind::module_ &m);
|
||||
|
||||
} // namespace python
|
||||
} // namespace mlir
|
||||
|
@ -448,6 +448,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
|
||||
MODULE_NAME _mlir
|
||||
ADD_TO_PARENT MLIRPythonSources.Core
|
||||
ROOT_DIR "${PYTHON_SOURCE_DIR}"
|
||||
PYTHON_BINDINGS_LIBRARY nanobind
|
||||
SOURCES
|
||||
MainModule.cpp
|
||||
IRAffine.cpp
|
||||
@ -463,7 +464,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
|
||||
Globals.h
|
||||
IRModule.h
|
||||
Pass.h
|
||||
PybindUtils.h
|
||||
NanobindUtils.h
|
||||
Rewrite.h
|
||||
PRIVATE_LINK_LIBS
|
||||
LLVMSupport
|
||||
|
@ -1,4 +1,4 @@
|
||||
nanobind>=2.0, <3.0
|
||||
nanobind>=2.4, <3.0
|
||||
numpy>=1.19.5, <=2.1.2
|
||||
pybind11>=2.10.0, <=2.13.6
|
||||
PyYAML>=5.4.0, <=6.0.1
|
||||
|
@ -176,5 +176,6 @@ def testWalkSymbolTables():
|
||||
try:
|
||||
SymbolTable.walk_symbol_tables(m.operation, True, error_callback)
|
||||
except RuntimeError as e:
|
||||
# CHECK: GOT EXCEPTION: Exception raised in callback: AssertionError: Raised from python
|
||||
# CHECK: GOT EXCEPTION: Exception raised in callback:
|
||||
# CHECK: AssertionError: Raised from python
|
||||
print(f"GOT EXCEPTION: {e}")
|
||||
|
@ -168,9 +168,9 @@ maybe(
|
||||
http_archive,
|
||||
name = "nanobind",
|
||||
build_file = "@llvm-raw//utils/bazel/third_party_build:nanobind.BUILD",
|
||||
sha256 = "bfbfc7e5759f1669e4ddb48752b1ddc5647d1430e94614d6f8626df1d508e65a",
|
||||
strip_prefix = "nanobind-2.2.0",
|
||||
url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.2.0.tar.gz",
|
||||
sha256 = "bb35deaed7efac5029ed1e33880a415638352f757d49207a8e6013fefb6c49a7",
|
||||
strip_prefix = "nanobind-2.4.0",
|
||||
url = "https://github.com/wjakob/nanobind/archive/refs/tags/v2.4.0.tar.gz",
|
||||
)
|
||||
|
||||
load("@rules_python//python:repositories.bzl", "py_repositories", "python_register_toolchains")
|
||||
|
@ -1044,6 +1044,9 @@ cc_library(
|
||||
srcs = [":MLIRBindingsPythonSourceFiles"],
|
||||
copts = PYBIND11_COPTS,
|
||||
features = PYBIND11_FEATURES,
|
||||
includes = [
|
||||
"lib/Bindings/Python",
|
||||
],
|
||||
textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
|
||||
deps = [
|
||||
":CAPIAsync",
|
||||
@ -1051,11 +1054,11 @@ cc_library(
|
||||
":CAPIIR",
|
||||
":CAPIInterfaces",
|
||||
":CAPITransforms",
|
||||
":MLIRBindingsPythonHeadersAndDeps",
|
||||
":MLIRBindingsPythonNanobindHeadersAndDeps",
|
||||
":Support",
|
||||
":config",
|
||||
"//llvm:Support",
|
||||
"@pybind11",
|
||||
"@nanobind",
|
||||
"@rules_python//python/cc:current_py_cc_headers",
|
||||
],
|
||||
)
|
||||
@ -1065,17 +1068,20 @@ cc_library(
|
||||
srcs = [":MLIRBindingsPythonSourceFiles"],
|
||||
copts = PYBIND11_COPTS,
|
||||
features = PYBIND11_FEATURES,
|
||||
includes = [
|
||||
"lib/Bindings/Python",
|
||||
],
|
||||
textual_hdrs = [":MLIRBindingsPythonCoreHeaders"],
|
||||
deps = [
|
||||
":CAPIAsyncHeaders",
|
||||
":CAPIDebugHeaders",
|
||||
":CAPIIRHeaders",
|
||||
":CAPITransformsHeaders",
|
||||
":MLIRBindingsPythonHeaders",
|
||||
":MLIRBindingsPythonNanobindHeaders",
|
||||
":Support",
|
||||
":config",
|
||||
"//llvm:Support",
|
||||
"@pybind11",
|
||||
"@nanobind",
|
||||
"@rules_python//python/cc:current_py_cc_headers",
|
||||
],
|
||||
)
|
||||
@ -1108,6 +1114,7 @@ cc_binary(
|
||||
deps = [
|
||||
":MLIRBindingsPythonCore",
|
||||
":MLIRBindingsPythonHeadersAndDeps",
|
||||
"@nanobind",
|
||||
],
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user