Maksim Levental f0ef5dba6d
[mlir][Python] create MLIRPythonSupport (#171775)
# What

This PR adds a shared library `MLIRPythonSupport` which contains all of
the CRTP classes like `PyConcreteValue`, `PyConcreteType`, and
`PyConcreteAttribute`, as well as other useful code like `Defaulting*`,
enabling their reuse in downstream projects. Downstream projects
can now do

```c++
struct PyTestType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
  ...
};

class PyTestAttr : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<PyTestAttr> {
  ...
};

NB_MODULE(_mlirPythonTestNanobind, m) {
  PyTestType::bind(m);
  PyTestAttr::bind(m);
}
```

instead of using the discordant alternative
`mlir_type_subclass`/`mlir_attr_subclass` (same goes for
`PyConcreteValue`/`mlir_value_subclass`).
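
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of how a CRTP base can supply a shared static `bind` to every derived class. It is not the actual `PyConcreteType` implementation: `FakeModule`, `ConcreteTypeSketch`, `TestTypeSketch`, `TestAttrSketch`, `pyClassName`, and `bindDerived` are hypothetical names chosen for illustration, and no nanobind APIs are used.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for a Python module: maps class name -> description.
struct FakeModule {
  std::map<std::string, std::string> classes;
};

// CRTP base: the shared registration logic is written once here.
template <typename Derived>
struct ConcreteTypeSketch {
  static void bind(FakeModule &m) {
    // The derived class supplies its name; the base performs registration.
    m.classes[Derived::pyClassName] = "bound via CRTP base";
    // Hook for class-specific members (resolved statically on Derived).
    Derived::bindDerived(m);
  }
  // Default hook: nothing extra to bind.
  static void bindDerived(FakeModule &) {}
};

// Uses the default hook.
struct TestTypeSketch : ConcreteTypeSketch<TestTypeSketch> {
  static constexpr const char *pyClassName = "TestType";
};

// Overrides the hook to add class-specific members.
struct TestAttrSketch : ConcreteTypeSketch<TestAttrSketch> {
  static constexpr const char *pyClassName = "TestAttr";
  static void bindDerived(FakeModule &m) {
    m.classes["TestAttr"] += " (+ custom members)";
  }
};
```

Calling `TestTypeSketch::bind(m)` and `TestAttrSketch::bind(m)` on a `FakeModule m` registers both classes through the same base code path, which is the property that lets downstream types share the upstream machinery.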

# Why

This PR is mostly code motion (along with CMake) but before I describe
the changes I want to state the goals/benefits:

1. Currently upstream "core" extensions and "dialect" extensions ([all
of the `Dialect*` extensions
here](d7c734b5a1/mlir/lib/Bindings/Python))
are a two-tier system;
**a**. [core
extensions](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/IRTypes.cpp#L361)
enjoy first class support as far as type inference[^3], type stub
generation, and ease of implementation, while dialect extensions [have
poorer support](https://reviews.llvm.org/D150927), incorrect type stub
generation, and a much more tedious (boilerplate-heavy) implementation;
**b**. Crucially, this two-tiered system is reflected in the fact that
**the two sets of types/attributes are not in the same Python object
hierarchy**. To wit: `isinstance(..., Type)` and `isinstance(...,
Attribute)` are not supported for the dialect extensions[^2];
**c**. Since these types are not exposed in public headers, downstream
users (dialect extensions or not) cannot write functions that overload
on e.g. `PyFloat8*Type` - that's quite a [useful
feature](fdbee98df8/cpp_ext/TorchOps.cpp (L29-L69))!
2. The dialect extensions incur a sizeable performance penalty relative
to the core extensions in that every single trip across the wire (either
`python->cpp` or `cpp->python`) requires work in addition to nanobind's
own casting/construction pipeline;
**a**. When going from `python->cpp`, [we extract the capsule object
from the Python
object](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h#L219C24-L219C46)
and then extract from the capsule the `Mlir*` opaque struct/ptr. This
side isn't so onerous;
**b**. When going from `cpp->python` we call the Python `import` APIs
long-hand and construct the Python object using `_CAPICreate`. Note that
there are at least 2 `attr` calls incurred in addition to `_CAPICreate`;
this is already much more [efficiently handled by nanobind
itself](4ba51fcf79/src/nb_internals.h (L381-L382))!
3. This division blocks various features: in some configurations[^1] we
trigger a circular import bug because "dialect" types and attributes
perform an [import of the root `_mlir`
module](bd9651bf78/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (L585))
when they are created (the types themselves, not even instances of those
types). This blocks type stub generation for dialect extensions (i.e.,
the reason we currently only generate type stubs for `_mlir`).
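
To make points 1b and 1c concrete, here is a small sketch under stated assumptions: `OpaqueTypeSketch`, `TypeSketch`, `Float8TypeSketch`, `IntegerTypeSketch`, and `describe` are hypothetical names, not the real binding classes. It only illustrates that once concrete wrapper classes live in one hierarchy and are visible in headers, both hierarchy checks and overloading on the concrete class become possible, whereas a single opaque handle type gives overload resolution nothing to dispatch on.

```cpp
#include <cassert>
#include <string>
#include <type_traits>

// Hypothetical opaque handle, similar in spirit to the C API's `Mlir*` structs.
struct OpaqueTypeSketch {
  const void *ptr;
};

// Hypothetical concrete wrappers that share one base class.
struct TypeSketch {
  OpaqueTypeSketch raw;
};
struct Float8TypeSketch : TypeSketch {};
struct IntegerTypeSketch : TypeSketch {};

// C++ analogue of `isinstance(x, Type)` holding for a derived wrapper.
static_assert(std::is_base_of<TypeSketch, Float8TypeSketch>::value,
              "concrete wrappers share one hierarchy");

// Overloads that dispatch on the concrete wrapper class.
inline std::string describe(const TypeSketch &) { return "generic type"; }
inline std::string describe(const Float8TypeSketch &) { return "float8 type"; }
inline std::string describe(const IntegerTypeSketch &) {
  return "integer type";
}
```

With only `OpaqueTypeSketch` available, all three `describe` overloads would collapse into one signature, and the distinction would have to be recovered with runtime checks.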

# How

Previously this was not possible because of "ODR" issues, but I have
resolved them; the basic idea for how we solve this is "move
things we want to share into shared libraries":

1. Move IRCore (stuff like `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute`) into `MLIRPythonSupport`;
- Note, we move the rest of the things in `IRModule.h` (renamed to
`IRCore.h`) because `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute` depend on them. This makes for a bigger PR than
one would hope for but ultimately I think we should give people access
to these classes to use as they see fit (specifically inherit from, but
also liberally use in bindings signatures instead of the opaque `Mlir*`
struct wrappers).
2. Put all of this code into a nested namespace
`MLIR_BINDINGS_PYTHON_DOMAIN` which is determined by a compile time
define (and tied to `MLIR_BINDINGS_PYTHON_NB_DOMAIN`). This is necessary
in order to prevent conflicts on both symbol name **and** typeid
(necessary for nanobind to not double register bound types) between
multiple bindings libraries (e.g., `torch-mlir` and `jax`). Note
[nanobind doesn't support `module_local` like
pybind11](https://nanobind.readthedocs.io/en/latest/porting.html#removed-features).
It does support `NB_DOMAIN` but that is not sufficient for
disambiguating typeids across projects (to wit: we currently define
`NB_DOMAIN` and it was still necessary to move everything to a nested
namespace);
3. Build the [nanobind library itself as a shared
object](https://github.com/wjakob/nanobind/blob/master/cmake/nanobind-config.cmake#L127)
(and link it to both the extensions and `MLIRPythonSupport`).
4. CMake to make this work, in-tree, out-of-tree, downstream, upstream,
etc.
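
The namespace trick in step 2 can be sketched in isolation. This is a minimal illustration, assuming two hypothetical domain names (`projectA`, `projectB`) and a hypothetical `PyThing` class; in a real build each project would compile the same headers with a single domain passed in as a define. The point is that identically named classes in different nested namespaces are distinct types with distinct symbols and distinct `typeid`s.

```cpp
#include <cassert>
#include <string>
#include <typeinfo>

// Hypothetical domain names; a real build would supply exactly one via -D.
#define DOMAIN_A projectA
#define DOMAIN_B projectB

namespace bindings {
namespace DOMAIN_A {
struct PyThing {
  int value = 1;
};
inline std::string owner() { return "projectA"; }
} // namespace DOMAIN_A

namespace DOMAIN_B {
struct PyThing {
  int value = 2;
};
inline std::string owner() { return "projectB"; }
} // namespace DOMAIN_B
} // namespace bindings

// The two PyThing classes share a spelling but are unrelated types.
inline bool distinctTypeIds() {
  return typeid(bindings::DOMAIN_A::PyThing) !=
         typeid(bindings::DOMAIN_B::PyThing);
}
```

Because the macro expands before name lookup, `bindings::projectA::owner` and `bindings::projectB::owner` mangle to different symbols, so two bindings libraries loaded into one process would not collide on either symbol names or type registration keys.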

# Testing

Three tests are added here:

1. `PythonTestModuleNanobind` is ported to use
`PyConcreteType<PyTestType>` instead of `mlir_type_subclass` and
`PyConcreteAttribute<PyTestAttr>` instead of `mlir_attr_subclass`,
verifying this works for non-core extensions in-tree;
2. `StandaloneExtensionNanobind` is ported to use `struct PyCustomType :
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyCustomType>`
instead of `mlir_type_subclass`, verifying this works for non-core
extensions out-of-tree;
3. `StandaloneExtensionNanobind`'s `smoketest` is extended to also load
another bindings package (namely `mlir`), verifying
`MLIR_BINDINGS_PYTHON_DOMAIN` successfully disambiguates symbols and
typeids.

I have also tested this downstream:
https://github.com/llvm/eudsl/pull/287 as well as run the following
buildbots:

mlir-nvidia-gcc7:
https://lab.llvm.org/buildbot/#/buildrequests/6654424?redirect_to_build=true

I have also tested against IREE:
https://github.com/iree-org/iree/pull/21916

# Integration

It is highly recommended to set the CMake var
`MLIR_BINDINGS_PYTHON_NB_DOMAIN` (which will also determine
`MLIR_BINDINGS_PYTHON_DOMAIN`) to something unique for each downstream.
This can also be passed explicitly to `add_mlir_python_modules` if your
project builds multiple bindings packages. I added a `WARNING` to this
effect in `AddMLIRPython.cmake`.

[^3]: Python values being typed correctly when exiting from cpp;
[^1]: Specifically when the modules are imported using `importlib`,
which occurs with nanobind's
[stubgen](https://github.com/wjakob/nanobind/blob/master/src/stubgen.py#L965);
[^2]: The workaround we implemented was a class method for the dialect
bindings called `Class.isinstance(...)`;
2026-01-05 09:08:13 -08:00

//===- Pass.cpp - Pass Management -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Pass.h"
#include "mlir-c/Pass.h"
#include "mlir/Bindings/Python/Globals.h"
#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;
namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {
/// Owning Wrapper around a PassManager.
class PyPassManager {
public:
  PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
  PyPassManager(PyPassManager &&other) noexcept
      : passManager(other.passManager) {
    other.passManager.ptr = nullptr;
  }
  ~PyPassManager() {
    if (!mlirPassManagerIsNull(passManager))
      mlirPassManagerDestroy(passManager);
  }
  MlirPassManager get() { return passManager; }

  void release() { passManager.ptr = nullptr; }
  nb::object getCapsule() {
    return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
  }

  static nb::object createFromCapsule(const nb::object &capsule) {
    MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
    if (mlirPassManagerIsNull(rawPm))
      throw nb::python_error();
    return nb::cast(PyPassManager(rawPm), nb::rv_policy::move);
  }

private:
  MlirPassManager passManager;
};

enum PyMlirPassDisplayMode : std::underlying_type_t<MlirPassDisplayMode> {
  MLIR_PASS_DISPLAY_MODE_LIST = MLIR_PASS_DISPLAY_MODE_LIST,
  MLIR_PASS_DISPLAY_MODE_PIPELINE = MLIR_PASS_DISPLAY_MODE_PIPELINE
};

struct PyMlirExternalPass : MlirExternalPass {};

/// Create the `mlir.passmanager` here.
void populatePassManagerSubmodule(nb::module_ &m) {
  //----------------------------------------------------------------------------
  // Mapping of enumerated types
  //----------------------------------------------------------------------------
  nb::enum_<PyMlirPassDisplayMode>(m, "PassDisplayMode")
      .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST)
      .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE);

  //----------------------------------------------------------------------------
  // Mapping of MlirExternalPass
  //----------------------------------------------------------------------------
  nb::class_<PyMlirExternalPass>(m, "ExternalPass")
      .def("signal_pass_failure", [](PyMlirExternalPass pass) {
        mlirExternalPassSignalFailure(pass);
      });

  //----------------------------------------------------------------------------
  // Mapping of the top-level PassManager
  //----------------------------------------------------------------------------
  nb::class_<PyPassManager>(m, "PassManager")
      .def(
          "__init__",
          [](PyPassManager &self, const std::string &anchorOp,
             DefaultingPyMlirContext context) {
            MlirPassManager passManager = mlirPassManagerCreateOnOperation(
                context->get(),
                mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
            new (&self) PyPassManager(passManager);
          },
          "anchor_op"_a = nb::str("any"), "context"_a = nb::none(),
          // clang-format off
          nb::sig("def __init__(self, anchor_op: str = 'any', context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> None"),
          // clang-format on
          "Create a new PassManager for the current (or provided) Context.")
      .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
      .def("_testing_release", &PyPassManager::release,
           "Releases (leaks) the backing pass manager (testing)")
      .def(
          "enable_ir_printing",
          [](PyPassManager &passManager, bool printBeforeAll,
             bool printAfterAll, bool printModuleScope, bool printAfterChange,
             bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
             std::optional<int64_t> largeResourceLimit, bool enableDebugInfo,
             bool printGenericOpForm,
             std::optional<std::string> optionalTreePrintingPath) {
            MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
            if (largeElementsLimit) {
              mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
                                                         *largeElementsLimit);
              mlirOpPrintingFlagsElideLargeResourceString(flags,
                                                          *largeElementsLimit);
            }
            if (largeResourceLimit)
              mlirOpPrintingFlagsElideLargeResourceString(flags,
                                                          *largeResourceLimit);
            if (enableDebugInfo)
              mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
                                                 /*prettyForm=*/false);
            if (printGenericOpForm)
              mlirOpPrintingFlagsPrintGenericOpForm(flags);
            std::string treePrintingPath = "";
            if (optionalTreePrintingPath.has_value())
              treePrintingPath = optionalTreePrintingPath.value();
            mlirPassManagerEnableIRPrinting(
                passManager.get(), printBeforeAll, printAfterAll,
                printModuleScope, printAfterChange, printAfterFailure, flags,
                mlirStringRefCreate(treePrintingPath.data(),
                                    treePrintingPath.size()));
            mlirOpPrintingFlagsDestroy(flags);
          },
          "print_before_all"_a = false, "print_after_all"_a = true,
          "print_module_scope"_a = false, "print_after_change"_a = false,
          "print_after_failure"_a = false,
          "large_elements_limit"_a = nb::none(),
          "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false,
          "print_generic_op_form"_a = false,
          "tree_printing_dir_path"_a = nb::none(),
          "Enable IR printing, default as mlir-print-ir-after-all.")
      .def(
          "enable_verifier",
          [](PyPassManager &passManager, bool enable) {
            mlirPassManagerEnableVerifier(passManager.get(), enable);
          },
          "enable"_a, "Enable / disable verify-each.")
      .def(
          "enable_timing",
          [](PyPassManager &passManager) {
            mlirPassManagerEnableTiming(passManager.get());
          },
          "Enable pass timing.")
      .def(
          "enable_statistics",
          [](PyPassManager &passManager, PyMlirPassDisplayMode displayMode) {
            mlirPassManagerEnableStatistics(
                passManager.get(),
                static_cast<MlirPassDisplayMode>(displayMode));
          },
          "displayMode"_a = MLIR_PASS_DISPLAY_MODE_PIPELINE,
          "Enable pass statistics.")
      .def_static(
          "parse",
          [](const std::string &pipeline, DefaultingPyMlirContext context) {
            MlirPassManager passManager = mlirPassManagerCreate(context->get());
            PyPrintAccumulator errorMsg;
            MlirLogicalResult status = mlirParsePassPipeline(
                mlirPassManagerGetAsOpPassManager(passManager),
                mlirStringRefCreate(pipeline.data(), pipeline.size()),
                errorMsg.getCallback(), errorMsg.getUserData());
            if (mlirLogicalResultIsFailure(status))
              throw nb::value_error(errorMsg.join().c_str());
            return new PyPassManager(passManager);
          },
          "pipeline"_a, "context"_a = nb::none(),
          // clang-format off
          nb::sig("def parse(pipeline: str, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> PassManager"),
          // clang-format on
          "Parse a textual pass-pipeline and return a top-level PassManager "
          "that can be applied on a Module. Throw a ValueError if the pipeline "
          "can't be parsed")
      .def(
          "add",
          [](PyPassManager &passManager, const std::string &pipeline) {
            PyPrintAccumulator errorMsg;
            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
                mlirPassManagerGetAsOpPassManager(passManager.get()),
                mlirStringRefCreate(pipeline.data(), pipeline.size()),
                errorMsg.getCallback(), errorMsg.getUserData());
            if (mlirLogicalResultIsFailure(status))
              throw nb::value_error(errorMsg.join().c_str());
          },
          "pipeline"_a,
          "Add textual pipeline elements to the pass manager. Throws a "
          "ValueError if the pipeline can't be parsed.")
      .def(
          "add",
          [](PyPassManager &passManager, const nb::callable &run,
             std::optional<std::string> &name, const std::string &argument,
             const std::string &description, const std::string &opName) {
            if (!name.has_value()) {
              name = nb::cast<std::string>(
                  nb::borrow<nb::str>(run.attr("__name__")));
            }
            MlirTypeID passID = PyGlobals::get().allocateTypeID();
            MlirExternalPassCallbacks callbacks;
            callbacks.construct = [](void *obj) {
              (void)nb::handle(static_cast<PyObject *>(obj)).inc_ref();
            };
            callbacks.destruct = [](void *obj) {
              (void)nb::handle(static_cast<PyObject *>(obj)).dec_ref();
            };
            callbacks.initialize = nullptr;
            callbacks.clone = [](void *) -> void * {
              throw std::runtime_error("Cloning Python passes not supported");
            };
            callbacks.run = [](MlirOperation op, MlirExternalPass pass,
                               void *userData) {
              nb::handle(static_cast<PyObject *>(userData))(
                  op, PyMlirExternalPass{pass.ptr});
            };
            auto externalPass = mlirCreateExternalPass(
                passID, mlirStringRefCreate(name->data(), name->length()),
                mlirStringRefCreate(argument.data(), argument.length()),
                mlirStringRefCreate(description.data(), description.length()),
                mlirStringRefCreate(opName.data(), opName.size()),
                /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr,
                callbacks, /*userData*/ run.ptr());
            mlirPassManagerAddOwnedPass(passManager.get(), externalPass);
          },
          "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "",
          "description"_a.none() = "", "op_name"_a.none() = "",
R"(
Add a python-defined pass to the current pipeline of the pass manager.
Args:
run: A callable with signature ``(op: ir.Operation, pass_: ExternalPass) -> None``.
Called when the pass executes. It receives the operation to be processed and
the current ``ExternalPass`` instance.
Use ``pass_.signal_pass_failure()`` to signal failure.
name: The name of the pass. Defaults to ``run.__name__``.
argument: The command-line argument for the pass. Defaults to empty.
description: The description of the pass. Defaults to empty.
op_name: The name of the operation this pass operates on.
It will be a generic operation pass if not specified.)")
      .def(
          "run",
          [](PyPassManager &passManager, PyOperationBase &op) {
            // Actually run the pass manager.
            PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
            MlirLogicalResult status = mlirPassManagerRunOnOp(
                passManager.get(), op.getOperation().get());
            if (mlirLogicalResultIsFailure(status))
              throw MLIRError("Failure while executing pass pipeline",
                              errors.take());
          },
          "operation"_a,
          // clang-format off
          nb::sig("def run(self, operation: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ") -> None"),
          // clang-format on
          "Run the pass manager on the provided operation, raising an "
          "MLIRError on failure.")
      .def(
          "__str__",
          [](PyPassManager &self) {
            MlirPassManager passManager = self.get();
            PyPrintAccumulator printAccum;
            mlirPrintPassPipeline(
                mlirPassManagerGetAsOpPassManager(passManager),
                printAccum.getCallback(), printAccum.getUserData());
            return printAccum.join();
          },
          "Print the textual representation for this PassManager, suitable to "
          "be passed to `parse` for round-tripping.");
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir