llvm-project/mlir/lib/Bindings/Python/TransformInterpreter.cpp
Maksim Levental b2a7369631
[MLIR][Python] remove liveOperations (#155114)
Historical context: `PyMlirContext::liveOperations` was an optimization
meant to cut down on the number of Python object allocations and
(partially) a mechanism for updating the validity of ops after a
transformation (e.g. while walking/transforming the AST). See the original
patch [here](https://reviews.llvm.org/D87958).
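For readers unfamiliar with the removed optimization, here is a simplified Python model of the idea. It is not the actual C++ implementation. It shows a per-context cache that returns the same wrapper object for the same underlying op pointer, and lets the context mark every cached wrapper invalid after a transformation. All names (`FakeOperation`, `InterningContext`, `get_operation`, `invalidate_all`) are hypothetical, and `weakref.WeakValueDictionary` is an assumed stand-in for the bookkeeping the real map inside `PyMlirContext` performed.

```python
import weakref


class FakeOperation:
    """Hypothetical stand-in for a Python wrapper around a native op pointer."""

    def __init__(self, ptr):
        self.ptr = ptr
        self.valid = True


class InterningContext:
    """Simplified model of a liveOperations-style cache (hypothetical)."""

    def __init__(self):
        # Weak values: the cache alone does not keep wrappers alive.
        self._live = weakref.WeakValueDictionary()

    def get_operation(self, ptr):
        # Reuse the existing wrapper for this pointer if one is still alive.
        op = self._live.get(ptr)
        if op is None:
            op = FakeOperation(ptr)
            self._live[ptr] = op
        return op

    def invalidate_all(self):
        # Mark every still-live wrapper invalid, e.g. after a transformation,
        # then forget them so later lookups allocate fresh wrappers.
        for op in list(self._live.values()):
            op.valid = False
        self._live.clear()


ctx = InterningContext()
a = ctx.get_operation(0x1000)
b = ctx.get_operation(0x1000)
print(a is b)  # True: the cached wrapper is reused
ctx.invalidate_all()
print(a.valid)  # False
```

Dropping a cache like this is what makes identity (`is`) between two lookups of the same op stop holding.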

Inspired by a
[renewed](https://github.com/llvm/llvm-project/pull/139721#issuecomment-3217131918)
interest in https://github.com/llvm/llvm-project/pull/139721 (which has
become a little stale...)

<p align="center">
<img width="504" height="375" alt="image"
src="https://github.com/user-attachments/assets/0daad562-d3d1-4876-8d01-5dba382ab186"
/>
</p>

In the previous go-around
(https://github.com/llvm/llvm-project/pull/92631), there were two issues,
which have now been resolved:

1. Ops that were "fetched" under a root op that has since been transformed
are no longer reported as invalid. We simply "[formally forbid](https://github.com/llvm/llvm-project/pull/92631#issuecomment-2119397018)"
using them;
2. `Module._CAPICreate(module_capsule)` must now be followed by a
`module._clear_mlir_module()` to prevent double-freeing of the actual
`ModuleOp` object (i.e. running the destructor of the
`OwningOpRef<ModuleOp>` twice):
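   ```python
   from mlir.ir import Context, Module

   with Context():
       module = Module.parse("module {}")
       module_dup = Module._CAPICreate(module._CAPIPtr)
       # Drop ownership on the original so only `module_dup` frees the ModuleOp.
       module._clear_mlir_module()
   ```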
   - **The alternative choice** here would be to remove the
   `Module._CAPICreate` API altogether and replace it with something like
   `Module._move(module)`, which would do both `Module._CAPICreate` and
   `module._clear_mlir_module`.
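The ownership problem in item 2 can be sketched with a toy Python model, assuming each wrapper "frees" its handle at most once when closed. `FakeModule`, `capi_create`, `clear`, `move`, `close`, and the `freed` log are all hypothetical names. `move` corresponds to the proposed, not existing, `Module._move`.

```python
# Log of native handles "freed" by the toy model, used to spot double frees.
freed = []


class FakeModule:
    """Hypothetical owner of a native module handle."""

    def __init__(self, handle):
        self._handle = handle  # None means this wrapper owns nothing

    @classmethod
    def capi_create(cls, handle):
        # Analogous to `_CAPICreate`: wrap the handle and take ownership.
        return cls(handle)

    def clear(self):
        # Analogous to `_clear_mlir_module`: drop ownership without freeing.
        self._handle = None

    @classmethod
    def move(cls, other):
        # Analogous to the proposed `_move`: create the new owner and clear
        # the old one in one call, so the second step cannot be forgotten.
        new = cls.capi_create(other._handle)
        other.clear()
        return new

    def close(self):
        # Stand-in for the owning reference's destructor: free once, if owned.
        if self._handle is not None:
            freed.append(self._handle)
            self._handle = None


# Two owners of the same handle: closing both frees it twice.
a = FakeModule(1)
b = FakeModule.capi_create(a._handle)
a.close()
b.close()
print(freed)  # [1, 1]

# Transferring ownership with `move` frees it exactly once.
freed.clear()
m = FakeModule(2)
dup = FakeModule.move(m)
m.close()
dup.close()
print(freed)  # [2]
```

Folding both steps into one call trades a slightly larger API change for removing the possibility of forgetting the `clear` step.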

Note: the other approach I explored last year was a [weakref
system](https://github.com/llvm/llvm-project/pull/97340) for
`mlir::Operation`, which would effectively hoist this `liveOperations`
machinery into MLIR core. It is possibly doable, but I now believe it's a bad idea.

The other potentially breaking change concerns `is`: because it checks
object identity rather than value equality, it will now report `False`,
since we always allocate new Python objects (i.e. that's the whole point
of this change). Users wanting to check equality for `Operation` and
`Module` should use `==`.
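A minimal sketch of why `is` stops working while `==` keeps working, assuming a wrapper whose `__eq__` compares the underlying native pointer. `FakeOperation` and `get_operation` are hypothetical stand-ins, not the real `mlir.ir` classes.

```python
class FakeOperation:
    """Hypothetical wrapper: a fresh Python object per lookup, compared by
    the underlying native pointer rather than by Python identity."""

    def __init__(self, ptr):
        self.ptr = ptr

    def __eq__(self, other):
        if not isinstance(other, FakeOperation):
            return NotImplemented
        return self.ptr == other.ptr

    def __hash__(self):
        # Keep hashing consistent with __eq__ so wrappers work as dict keys.
        return hash(self.ptr)


def get_operation(ptr):
    # No cache: every call allocates a new wrapper.
    return FakeOperation(ptr)


a = get_operation(0x1000)
b = get_operation(0x1000)
print(a is b)  # False: two distinct Python objects
print(a == b)  # True: same underlying operation
```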
2025-09-01 21:53:33 -07:00


//===- TransformInterpreter.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Nanobind classes for the transform dialect interpreter.
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/Transform/Interpreter.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"

namespace nb = nanobind;

namespace {
struct PyMlirTransformOptions {
  PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }
  PyMlirTransformOptions(PyMlirTransformOptions &&other) {
    options = other.options;
    other.options.ptr = nullptr;
  }
  PyMlirTransformOptions(const PyMlirTransformOptions &) = delete;

  ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); }

  MlirTransformOptions options;
};
} // namespace

static void populateTransformInterpreterSubmodule(nb::module_ &m) {
  nb::class_<PyMlirTransformOptions>(m, "TransformOptions")
      .def(nb::init<>())
      .def_prop_rw(
          "expensive_checks",
          [](const PyMlirTransformOptions &self) {
            return mlirTransformOptionsGetExpensiveChecksEnabled(self.options);
          },
          [](PyMlirTransformOptions &self, bool value) {
            mlirTransformOptionsEnableExpensiveChecks(self.options, value);
          })
      .def_prop_rw(
          "enforce_single_top_level_transform_op",
          [](const PyMlirTransformOptions &self) {
            return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(
                self.options);
          },
          [](PyMlirTransformOptions &self, bool value) {
            mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options,
                                                                 value);
          });

  m.def(
      "apply_named_sequence",
      [](MlirOperation payloadRoot, MlirOperation transformRoot,
         MlirOperation transformModule, const PyMlirTransformOptions &options) {
        mlir::python::CollectDiagnosticsToStringScope scope(
            mlirOperationGetContext(transformRoot));

        // No Python-side invalidation is performed here: ops previously
        // fetched from under `payloadRoot` are no longer tracked, and using
        // them after the transform has been applied is formally forbidden.
        MlirLogicalResult result = mlirTransformApplyNamedSequence(
            payloadRoot, transformRoot, transformModule, options.options);
        if (mlirLogicalResultIsSuccess(result))
          return;

        throw nb::value_error(
            ("Failed to apply named transform sequence.\nDiagnostic message " +
             scope.takeMessage())
                .c_str());
      },
      nb::arg("payload_root"), nb::arg("transform_root"),
      nb::arg("transform_module"),
      nb::arg("transform_options") = PyMlirTransformOptions());

  m.def(
      "copy_symbols_and_merge_into",
      [](MlirOperation target, MlirOperation other) {
        mlir::python::CollectDiagnosticsToStringScope scope(
            mlirOperationGetContext(target));

        MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other);
        if (mlirLogicalResultIsFailure(result)) {
          throw nb::value_error(
              ("Failed to merge symbols.\nDiagnostic message " +
               scope.takeMessage())
                  .c_str());
        }
      },
      nb::arg("target"), nb::arg("other"));
}

NB_MODULE(_mlirTransformInterpreter, m) {
  m.doc() = "MLIR Transform dialect interpreter functionality.";
  populateTransformInterpreterSubmodule(m);
}