This PR continues the work of https://github.com/llvm/llvm-project/pull/171775 by moving more useful types/attributes into MLIRPythonSupport. You can now write

```c++
struct PyTestIntegerRankedTensorType
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
          PyTestIntegerRankedTensorType,
          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType>

struct PyTestTensorValue
    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
          PyTestTensorValue>
```

instead of using `mlir_type_subclass` and `mlir_value_subclass`. In particular, **manually registering the "value caster" via indirection through the Python interpreter is no longer necessary**. You can also now freely use all such types at the nanobind API level (e.g., overload based on the `FP*` types):

```c++
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
standaloneM.def("print_fp_type", [](PyF16Type &) { nb::print("this is a fp16 type"); });
standaloneM.def("print_fp_type", [](PyF32Type &) { nb::print("this is a fp32 type"); });
standaloneM.def("print_fp_type", [](PyF64Type &) { nb::print("this is a fp64 type"); });
```

Note: this PR only ports `PythonTestModuleNanobind`. A follow-up PR (https://github.com/llvm/llvm-project/pull/174156) ports **all** in-tree dialect extensions to use these types. Once that one lands, we can soft-deprecate `mlir_pure_subclass`.

Note: this PR depends on https://github.com/llvm/llvm-project/pull/171775.
160 lines
5.9 KiB
C++
160 lines
5.9 KiB
C++
//===- PythonTestModuleNanobind.cpp - PythonTest dialect extension --------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
// This is the nanobind edition of the PythonTest dialect module.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PythonTestCAPI.h"
|
|
#include "mlir-c/BuiltinAttributes.h"
|
|
#include "mlir-c/BuiltinTypes.h"
|
|
#include "mlir-c/Diagnostics.h"
|
|
#include "mlir-c/IR.h"
|
|
#include "mlir/Bindings/Python/Diagnostics.h"
|
|
#include "mlir/Bindings/Python/IRCore.h"
|
|
#include "mlir/Bindings/Python/IRTypes.h"
|
|
#include "mlir/Bindings/Python/Nanobind.h"
|
|
#include "mlir/Bindings/Python/NanobindAdaptors.h"
|
|
#include "nanobind/nanobind.h"
|
|
|
|
namespace nb = nanobind;
|
|
using namespace mlir::python::nanobind_adaptors;
|
|
|
|
/// Returns true when `t` is a ranked tensor type whose element type satisfies
/// `mlirTypeIsAInteger`. The element type is only queried once `t` is known to
/// be a ranked tensor.
static bool mlirTypeIsARankedIntegerTensor(MlirType t) {
  if (!mlirTypeIsARankedTensor(t))
    return false;
  MlirType elementType = mlirShapedTypeGetElementType(t);
  return mlirTypeIsAInteger(elementType);
}
|
|
|
|
struct PyTestType
|
|
: mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsAPythonTestTestType;
|
|
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
|
mlirPythonTestTestTypeGetTypeID;
|
|
static constexpr const char *pyClassName = "TestType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
|
|
context) {
|
|
return PyTestType(context->getRef(),
|
|
mlirPythonTestTestTypeGet(context.get()->get()));
|
|
},
|
|
nb::arg("context").none() = nb::none());
|
|
}
|
|
};
|
|
|
|
struct PyTestIntegerRankedTensorType
|
|
: mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
|
|
PyTestIntegerRankedTensorType,
|
|
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
|
|
static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
|
|
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
|
mlirRankedTensorTypeGetTypeID;
|
|
static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
|
|
using PyConcreteType::PyConcreteType;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](std::vector<int64_t> shape, unsigned width,
|
|
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
|
|
ctx) {
|
|
MlirAttribute encoding = mlirAttributeGetNull();
|
|
return PyTestIntegerRankedTensorType(
|
|
ctx->getRef(),
|
|
mlirRankedTensorTypeGet(
|
|
shape.size(), shape.data(),
|
|
mlirIntegerTypeGet(ctx.get()->get(), width), encoding));
|
|
},
|
|
nb::arg("shape"), nb::arg("width"),
|
|
nb::arg("context").none() = nb::none());
|
|
}
|
|
};
|
|
|
|
struct PyTestTensorValue
|
|
: mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
|
|
PyTestTensorValue> {
|
|
static constexpr IsAFunctionTy isaFunction =
|
|
mlirTypeIsAPythonTestTestTensorValue;
|
|
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
|
mlirRankedTensorTypeGetTypeID;
|
|
static constexpr const char *pyClassName = "TestTensorValue";
|
|
using PyConcreteValue::PyConcreteValue;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); });
|
|
}
|
|
};
|
|
|
|
class PyTestAttr
|
|
: public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
|
|
PyTestAttr> {
|
|
public:
|
|
static constexpr IsAFunctionTy isaFunction =
|
|
mlirAttributeIsAPythonTestTestAttribute;
|
|
static constexpr const char *pyClassName = "TestAttr";
|
|
using PyConcreteAttribute::PyConcreteAttribute;
|
|
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
|
|
mlirPythonTestTestAttributeGetTypeID;
|
|
|
|
static void bindDerived(ClassTy &c) {
|
|
c.def_static(
|
|
"get",
|
|
[](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
|
|
context) {
|
|
return PyTestAttr(context->getRef(), mlirPythonTestTestAttributeGet(
|
|
context.get()->get()));
|
|
},
|
|
nb::arg("context").none() = nb::none());
|
|
}
|
|
};
|
|
|
|
/// Module initializer for `_mlirPythonTestNanobind`. It exposes the dialect
/// registration helpers and a diagnostics test hook, then binds the custom
/// attribute, type and value classes defined above.
NB_MODULE(_mlirPythonTestNanobind, m) {
  namespace mlirpy = mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;

  // register_python_test_dialect(context=None, load=True): registers the
  // python_test dialect with the context and loads it when `load` is true.
  m.def(
      "register_python_test_dialect",
      [](mlirpy::DefaultingPyMlirContext context, bool load) {
        MlirContext mlirCtx = context.get()->get();
        MlirDialectHandle handle = mlirGetDialectHandle__python_test__();
        mlirDialectHandleRegisterDialect(handle, mlirCtx);
        if (!load)
          return;
        mlirDialectHandleLoadDialect(handle, mlirCtx);
      },
      nb::arg("context").none() = nb::none(), nb::arg("load") = true);

  // register_dialect(registry): inserts the python_test dialect into the given
  // dialect registry.
  m.def(
      "register_dialect",
      [](MlirDialectRegistry registry) {
        mlirDialectHandleInsertDialect(mlirGetDialectHandle__python_test__(),
                                       registry);
      },
      nb::arg("registry"),
      // clang-format off
      nb::sig("def register_dialect(registry: " MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry") ") -> None"));
  // clang-format on

  // test_diagnostics_with_errors_and_notes(context=None): calls
  // mlirPythonTestEmitDiagnosticWithNote while diagnostics are collected into
  // a string, then raises ValueError carrying the collected text.
  m.def(
      "test_diagnostics_with_errors_and_notes",
      [](mlirpy::DefaultingPyMlirContext ctx) {
        MlirContext mlirCtx = ctx.get()->get();
        mlir::python::CollectDiagnosticsToStringScope collector(mlirCtx);
        mlirPythonTestEmitDiagnosticWithNote(mlirCtx);
        throw nb::value_error(collector.takeMessage().c_str());
      },
      nb::arg("context").none() = nb::none());

  // Bind the custom classes. The relative order is kept as originally written.
  PyTestAttr::bind(m);
  PyTestType::bind(m);
  PyTestIntegerRankedTensorType::bind(m);
  PyTestTensorValue::bind(m);
}
|