llvm-project/mlir/test/python/lib/PythonTestModuleNanobind.cpp
Maksim Levental 18fc908566
[mlir][Python] move IRTypes and IRAttributes to MLIRPythonSupport (#174118)
This PR continues the work of
https://github.com/llvm/llvm-project/pull/171775 by moving more useful
types/attributes into MLIRPythonSupport.

You can now write:

```c++
struct PyTestIntegerRankedTensorType
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
          PyTestIntegerRankedTensorType,
          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType>
struct PyTestTensorValue
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
          PyTestTensorValue>
```
instead of using `mlir_type_subclass` and `mlir_value_subclass`;
**specifically, manual registration of the "value caster" via indirection
through the Python interpreter is no longer necessary**. You can also
now freely use all such types at the nanobind API level (e.g., to overload
based on `FP*` types):

```c++
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); });
standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); });
standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); });
```

Note that this PR only ports `PythonTestModuleNanobind`; a follow-up PR,
https://github.com/llvm/llvm-project/pull/174156, ports **all** in-tree
dialect extensions to use these. Once that one lands, we can
soft-deprecate `mlir_pure_subclass`.

Note, depends on https://github.com/llvm/llvm-project/pull/171775
2026-01-05 09:34:58 -08:00

160 lines
5.9 KiB
C++

//===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is the nanobind edition of the PythonTest dialect module.
//===----------------------------------------------------------------------===//
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/IRCore.h"
#include "mlir/Bindings/Python/IRTypes.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
/// Returns true when `t` is a ranked tensor type whose element type is an
/// integer type. The element type is only queried after `t` has been
/// confirmed to be a ranked tensor.
static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
  if (!mlirTypeIsARankedTensor(t))
    return false;
  MlirType elementType = mlirShapedTypeGetElementType(t);
  return mlirTypeIsAInteger(elementType);
}
/// Binding for the PythonTest dialect's test type, exposed to Python under the
/// class name "TestType".
struct PyTestType
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
  using PyConcreteType::PyConcreteType;

  static constexpr const char *pyClassName = "TestType";
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirPythonTestTestTypeGetTypeID;

  /// Adds the static `get(context=None)` factory, which wraps the result of
  /// `mlirPythonTestTestTypeGet` for the supplied context.
  static void bindDerived(ClassTy &c) {
    using ContextArg =
        mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext;
    c.def_static(
        "get",
        [](ContextArg ctx) {
          auto rawType = mlirPythonTestTestTypeGet(ctx.get()->get());
          return PyTestType(ctx->getRef(), rawType);
        },
        nb::arg("context").none() = nb::none());
  }
};
/// Binding exposed to Python as "TestIntegerRankedTensorType". It derives from
/// the ranked tensor binding and uses `mlirTypeIsARankedIntegerTensor` as its
/// isa-check, while reusing the ranked tensor type ID.
struct PyTestIntegerRankedTensorType
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
          PyTestIntegerRankedTensorType,
          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
  using PyConcreteType::PyConcreteType;

  static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirRankedTensorTypeGetTypeID;

  /// Adds the static `get(shape, width, context=None)` factory. It builds a
  /// ranked tensor type with the given shape, an integer element type of the
  /// given bit width, and a null encoding attribute.
  static void bindDerived(ClassTy &c) {
    using ContextArg =
        mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext;
    c.def_static(
        "get",
        [](std::vector<int64_t> dims, unsigned bitWidth, ContextArg context) {
          MlirAttribute noEncoding = mlirAttributeGetNull();
          auto elementType = mlirIntegerTypeGet(context.get()->get(), bitWidth);
          auto tensorType = mlirRankedTensorTypeGet(dims.size(), dims.data(),
                                                    elementType, noEncoding);
          return PyTestIntegerRankedTensorType(context->getRef(), tensorType);
        },
        nb::arg("shape"), nb::arg("width"),
        nb::arg("context").none() = nb::none());
  }
};
/// Value binding exposed to Python as "TestTensorValue". Its isa-check is
/// `mlirTypeIsAPythonTestTestTensorValue`.
struct PyTestTensorValue
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
          PyTestTensorValue> {
  using PyConcreteValue::PyConcreteValue;

  static constexpr const char *pyClassName = "TestTensorValue";
  static constexpr IsAFunctionTy isaFunction =
      mlirTypeIsAPythonTestTestTensorValue;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirRankedTensorTypeGetTypeID;

  /// Adds the `is_null` method, which reports `mlirValueIsNull` for the value.
  static void bindDerived(ClassTy &c) {
    c.def("is_null",
          [](MlirValue &value) -> bool { return mlirValueIsNull(value); });
  }
};
/// Binding for the PythonTest dialect's test attribute, exposed to Python
/// under the class name "TestAttr".
class PyTestAttr
    : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
          PyTestAttr> {
public:
  using PyConcreteAttribute::PyConcreteAttribute;

  static constexpr const char *pyClassName = "TestAttr";
  static constexpr IsAFunctionTy isaFunction =
      mlirAttributeIsAPythonTestTestAttribute;
  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
      mlirPythonTestTestAttributeGetTypeID;

  /// Adds the static `get(context=None)` factory, which wraps the result of
  /// `mlirPythonTestTestAttributeGet` for the supplied context.
  static void bindDerived(ClassTy &c) {
    using ContextArg =
        mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext;
    c.def_static(
        "get",
        [](ContextArg ctx) {
          auto rawAttr = mlirPythonTestTestAttributeGet(ctx.get()->get());
          return PyTestAttr(ctx->getRef(), rawAttr);
        },
        nb::arg("context").none() = nb::none());
  }
};
// Module entry point: defines the dialect registration helpers and a
// diagnostics test function, then binds the PythonTest attribute, type, and
// value classes.
NB_MODULE(_mlirPythonTestNanobind, m) {
  using ContextArg =
      mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext;

  // Registers the python_test dialect handle with the context and, unless
  // `load` is false, also loads the dialect into it.
  m.def(
      "register_python_test_dialect",
      [](ContextArg context, bool load) {
        MlirDialectHandle handle = mlirGetDialectHandle__python_test__();
        auto rawContext = context.get()->get();
        mlirDialectHandleRegisterDialect(handle, rawContext);
        if (!load)
          return;
        mlirDialectHandleLoadDialect(handle, rawContext);
      },
      nb::arg("context").none() = nb::none(), nb::arg("load") = true);

  // Inserts the python_test dialect handle into the given dialect registry.
  m.def(
      "register_dialect",
      [](MlirDialectRegistry registry) {
        MlirDialectHandle handle = mlirGetDialectHandle__python_test__();
        mlirDialectHandleInsertDialect(handle, registry);
      },
      nb::arg("registry"),
      // clang-format off
      nb::sig("def register_dialect(registry: " MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry") ") -> None"));
  // clang-format on

  // Emits a test diagnostic with a note while a collecting scope is active,
  // then always raises a ValueError carrying the collected message.
  m.def(
      "test_diagnostics_with_errors_and_notes",
      [](ContextArg ctx) {
        auto rawContext = ctx.get()->get();
        mlir::python::CollectDiagnosticsToStringScope handler(rawContext);
        mlirPythonTestEmitDiagnosticWithNote(rawContext);
        auto message = handler.takeMessage();
        throw nb::value_error(message.c_str());
      },
      nb::arg("context").none() = nb::none());

  // Bind the Python-visible classes onto this module.
  PyTestAttr::bind(m);
  PyTestType::bind(m);
  PyTestIntegerRankedTensorType::bind(m);
  PyTestTensorValue::bind(m);
}