Maksim Levental f0ef5dba6d
[mlir][Python] create MLIRPythonSupport (#171775)
# What

This PR adds a shared library `MLIRPythonSupport` which contains all of
the CRTP classes like `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute`, as well as other useful code like `Defaulting*`,
enabling their reuse in downstream projects. Downstream projects
can now do

```c++
struct PyTestType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyTestType> {
  ...
};

class PyTestAttr : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<PyTestAttr> {
  ...
};

NB_MODULE(_mlirPythonTestNanobind, m) {
  PyTestType::bind(m);
  PyTestAttr::bind(m);
}
```

instead of using the discordant alternative
`mlir_type_subclass`/`mlir_attr_subclass` (same goes for
`PyConcreteValue`/`mlir_value_subclass`).
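
The following is a toy, self-contained model of how a CRTP base like `PyConcreteType<Derived>` lets every concrete class reuse one generic `bind`. Everything in it (`ToyModule`, `ToyConcreteType`, and the `pyClassName`/`bindDerived` hooks) is a hypothetical stand-in, not the real API. The actual base class binds into a nanobind `nb::module_` and places further requirements on the derived class.

```c++
#include <algorithm>
#include <cassert>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

// Hypothetical stand-in for the module object handed to NB_MODULE: it only
// records the names registered on it, in registration order.
struct ToyModule {
  std::vector<std::string> names;

  void add(std::string name) {
    // Toy invariant: each name may be registered only once.
    assert(std::find(names.begin(), names.end(), name) == names.end() &&
           "name registered twice");
    names.push_back(std::move(name));
  }
};

// Toy CRTP base, loosely modeled on PyConcreteType<DerivedTy>: one generic
// bind() shared by every concrete class. The derived class supplies its
// Python-visible name and may register extra members.
template <typename DerivedTy>
struct ToyConcreteType {
  static void bind(ToyModule &m) {
    // Compile-time dispatch into the derived class; no virtual calls.
    m.add(std::string(DerivedTy::pyClassName));
    DerivedTy::bindDerived(m);
  }

  // Default hook; a derived class shadows it to register extra members.
  static void bindDerived(ToyModule &) {}
};

// Downstream classes only spell out what is specific to them.
struct PyTestType : ToyConcreteType<PyTestType> {
  static constexpr std::string_view pyClassName = "TestType";
};

struct PyTestAttr : ToyConcreteType<PyTestAttr> {
  static constexpr std::string_view pyClassName = "TestAttr";

  static void bindDerived(ToyModule &m) {
    m.add(std::string(pyClassName) + ".get"); // e.g. a factory method
  }
};
```

Calling `PyTestType::bind(m)` and then `PyTestAttr::bind(m)` leaves `m.names` holding `TestType`, `TestAttr` and `TestAttr.get`, in that order. Because dispatch into the derived class happens at compile time, downstream classes pay no virtual-call cost and only declare what differs between them.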

# Why

This PR is mostly code motion (along with CMake changes), but before I
describe the changes I want to state the goals/benefits:

1. Currently upstream "core" extensions and "dialect" extensions ([all
of the `Dialect*` extensions
here](d7c734b5a1/mlir/lib/Bindings/Python))
are a two-tier system;
**a**. [core
extensions](https://github.com/llvm/llvm-project/blob/main/mlir/lib/Bindings/Python/IRTypes.cpp#L361)
enjoy first class support as far as type inference[^3], type stub
generation, and ease of implementation, while dialect extensions [have
poorer support](https://reviews.llvm.org/D150927): incorrect type stub
generation and a much more tedious (boilerplate-heavy) implementation;
**b**. Crucially, this two-tiered system is reflected in the fact that
**the two sets of types/attributes are not in the same Python object
hierarchy**. To wit: `isinstance(..., Type)` and `isinstance(...,
Attribute)` are not supported for the dialect extensions[^2];
**c**. Since these types are not exposed in public headers, downstream
users (dialect extensions or not) cannot write functions that overload
on e.g. `PyFloat8*Type` - that's quite a [useful
feature](fdbee98df8/cpp_ext/TorchOps.cpp (L29-L69))!
2. The dialect extensions incur a sizeable performance penalty relative
to the core extensions in that every single trip across the wire (either
`python->cpp` or `cpp->python`) requires work in addition to nanobind's
own casting/construction pipeline;
**a**. When going from `python->cpp`, [we extract the capsule object
from the Python
object](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h#L219C24-L219C46)
and then extract from the capsule the `Mlir*` opaque struct/ptr. This
side isn't so onerous;
**b**. When going from `cpp->python`, we make long-hand calls to Python
`import` APIs and construct the Python object using `_CAPICreate`. Note
that there are at least 2 `attr` calls incurred in addition to `_CAPICreate`;
this is already much more [efficiently handled by nanobind
itself](4ba51fcf79/src/nb_internals.h (L381-L382))!
3. This division blocks various features: in some configurations[^1] we
trigger a circular import bug because "dialect" types and attributes
perform an [import of the root `_mlir`
module](bd9651bf78/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (L585))
when they are created (the types themselves, not even instances of those
types). This blocks type stub generation for dialect extensions (which is
why we currently only generate type stubs for `_mlir`).
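
Point 1c is about ordinary C++ overloading. The minimal sketch below uses hypothetical stand-ins: `OpaqueType` plays the role of an opaque handle such as `MlirType`, and two thin wrapper structs are named after the `PyFloat8*Type` family. With only the opaque handle there is a single overload. With distinct wrapper types in a public header, the compiler can pick the right function.

```c++
#include <string_view>

// Hypothetical stand-in for an opaque C-API handle such as MlirType: one
// struct type used for every kind of type.
struct OpaqueType {
  const void *ptr;
};

// Hypothetical thin wrappers in the spirit of the concrete Py*Type classes:
// the same payload, but a distinct C++ type per kind.
struct PyFloat8E4M3FNType {
  OpaqueType type;
};
struct PyFloat8E5M2Type {
  OpaqueType type;
};

// With only the opaque handle there is a single overload; telling kinds apart
// would require a runtime query on the handle itself.
constexpr std::string_view describe(OpaqueType) { return "some type"; }

// With distinct wrapper types, overload resolution selects the right body at
// compile time.
constexpr std::string_view describe(PyFloat8E4M3FNType) { return "f8E4M3FN"; }
constexpr std::string_view describe(PyFloat8E5M2Type) { return "f8E5M2"; }
```

For example, `describe(PyFloat8E5M2Type{...})` resolves to the `f8E5M2` overload, while the bare handle can only reach the generic one.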

# How

Previously this was not done (or not possible) because of "ODR" issues,
but I have resolved those issues; the basic idea is to "move things we
want to share into shared libraries":

1. Move IRCore (stuff like `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute`) into `MLIRPythonSupport`;
- Note, we move the rest of the things in `IRModule.h` (renamed to
`IRCore.h`) because `PyConcreteValue`, `PyConcreteType`,
`PyConcreteAttribute` depend on them. This makes for a bigger PR than
one would hope for but ultimately I think we should give people access
to these classes to use as they see fit (specifically inherit from, but
also liberally use in bindings signatures instead of the opaque `Mlir*`
struct wrappers).
2. Put all of this code into a nested namespace
`MLIR_BINDINGS_PYTHON_DOMAIN`, which is determined by a compile-time
define (and tied to `MLIR_BINDINGS_PYTHON_NB_DOMAIN`). This is necessary
in order to prevent conflicts on both symbol name **and** typeid
(necessary for nanobind to not double register bound types) between
multiple bindings libraries (e.g., `torch-mlir` and `jax`). Note
[nanobind doesn't support `module_local` like
pybind11](https://nanobind.readthedocs.io/en/latest/porting.html#removed-features).
It does support `NB_DOMAIN` but that is not sufficient for
disambiguating typeids across projects (to wit: we currently define
`NB_DOMAIN` and it was still necessary to move everything to a nested
namespace);
3. Build the [nanobind library itself as a shared
object](https://github.com/wjakob/nanobind/blob/master/cmake/nanobind-config.cmake#L127)
(and link it to both the extensions and `MLIRPythonSupport`).
4. CMake changes to make this work in-tree, out-of-tree, downstream,
upstream, etc.
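
Step 2's namespace approach can be sketched in plain C++. This sketch assumes two hypothetical downstream projects, `project_a` and `project_b`. In a real build each project sets `MLIR_BINDINGS_PYTHON_DOMAIN` through its own compile-time define and compiles its own copy of the support library into a separate shared object. Here the macro is redefined by hand so that both copies fit in one translation unit. `PyType` below is an empty stand-in, not the real class.

```c++
#include <type_traits>

// --- Project A's copy of the (stand-in) support library. ---
#define MLIR_BINDINGS_PYTHON_DOMAIN project_a
namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN {
struct PyType {}; // empty stand-in for a shared class such as PyType
} // namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN

// Project A's extension names the base through the macro, as in the
// PyTestType example above, so it derives from project A's copy.
struct ProjectAType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType {};
#undef MLIR_BINDINGS_PYTHON_DOMAIN

// --- The same source, compiled with a different domain for project B. ---
#define MLIR_BINDINGS_PYTHON_DOMAIN project_b
namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN {
struct PyType {};
} // namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN

struct ProjectBType : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyType {};
#undef MLIR_BINDINGS_PYTHON_DOMAIN

// The two copies are distinct types. Without the extra namespace level both
// would be spelled mlir::python::PyType.
static_assert(!std::is_same_v<mlir::python::project_a::PyType,
                              mlir::python::project_b::PyType>);
```

Because the two copies differ in their fully qualified names, they also differ in mangled symbol names and in `typeid`, which is the property step 2 relies on.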

# Testing

Three tests are added here:

1. `PythonTestModuleNanobind` is ported to use
`PyConcreteType<PyTestType>` instead of `mlir_type_subclass` and
`PyConcreteAttribute<PyTestAttr>` instead of `mlir_attr_subclass`,
verifying this works for non-core extensions in-tree;
2. `StandaloneExtensionNanobind` is ported to use `struct PyCustomType :
mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<PyCustomType>`
instead of `mlir_type_subclass` verifying this works for non-core
extensions out-of-tree;
3. `StandaloneExtensionNanobind`'s `smoketest` is extended to also load
another bindings package (namely `mlir`) verifying
`MLIR_BINDINGS_PYTHON_DOMAIN` successfully disambiguates symbols and
typeids.

I have also tested this downstream
(https://github.com/llvm/eudsl/pull/287) and run the following builder
bots:

mlir-nvidia-gcc7:
https://lab.llvm.org/buildbot/#/buildrequests/6654424?redirect_to_build=true

I have also tested against IREE:
https://github.com/iree-org/iree/pull/21916

# Integration

It is highly recommended to set the CMake var
`MLIR_BINDINGS_PYTHON_NB_DOMAIN` (which will also determine
`MLIR_BINDINGS_PYTHON_DOMAIN`) to something unique for each downstream.
This can also be passed explicitly to `add_mlir_python_modules` if your
project builds multiple bindings packages. I added a `WARNING` to this
effect in `AddMLIRPython.cmake`.

[^1]: Specifically when the modules are imported using `importlib`,
which occurs with nanobind's
[stubgen](https://github.com/wjakob/nanobind/blob/master/src/stubgen.py#L965).
[^2]: The workaround we implemented was a class method for the dialect
bindings called `Class.isinstance(...)`.
[^3]: Python values being typed correctly when exiting from cpp.
2026-01-05 09:08:13 -08:00

//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Rewrite.h"

#include "mlir-c/IR.h"
#include "mlir-c/Rewrite.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/IRCore.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir/Config/mlir-config.h"
#include "nanobind/nanobind.h"

namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
using namespace mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN;

namespace mlir {
namespace python {
namespace MLIR_BINDINGS_PYTHON_DOMAIN {

class PyPatternRewriter {
public:
  PyPatternRewriter(MlirPatternRewriter rewriter)
      : base(mlirPatternRewriterAsBase(rewriter)),
        ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}

  PyInsertionPoint getInsertionPoint() const {
    MlirBlock block = mlirRewriterBaseGetInsertionBlock(base);
    MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base);

    if (mlirOperationIsNull(op)) {
      MlirOperation owner = mlirBlockGetParentOperation(block);
      auto parent = PyOperation::forOperation(ctx, owner);
      return PyInsertionPoint(PyBlock(parent, block));
    }

    return PyInsertionPoint(PyOperation::forOperation(ctx, op));
  }

  void replaceOp(MlirOperation op, MlirOperation newOp) {
    mlirRewriterBaseReplaceOpWithOperation(base, op, newOp);
  }

  void replaceOp(MlirOperation op, const std::vector<MlirValue> &values) {
    mlirRewriterBaseReplaceOpWithValues(base, op, values.size(), values.data());
  }

  void eraseOp(MlirOperation op) { mlirRewriterBaseEraseOp(base, op); }

private:
  MlirRewriterBase base;
  PyMlirContextRef ctx;
};

struct PyMlirPDLResultList : MlirPDLResultList {};

// Convert the Python object to a boolean.
// If it evaluates to False, treat it as success;
// otherwise, treat it as failure.
// Note that None is considered success.
// Kept outside the PDL guard because PyRewritePatternSet::add also uses it.
static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
  if (obj.is_none())
    return mlirLogicalResultSuccess();

  return nb::cast<bool>(obj) ? mlirLogicalResultFailure()
                             : mlirLogicalResultSuccess();
}

#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
static nb::object objectFromPDLValue(MlirPDLValue value) {
  if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v))
    return nb::cast(v);
  if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v))
    return nb::cast(v);
  if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v))
    return nb::cast(v);
  if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v))
    return nb::cast(v);

  throw std::runtime_error("unsupported PDL value type");
}

static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
                                                    MlirPDLValue *values) {
  std::vector<nb::object> args;
  args.reserve(nValues);
  for (size_t i = 0; i < nValues; ++i)
    args.push_back(objectFromPDLValue(values[i]));
  return args;
}

/// Owning Wrapper around a PDLPatternModule.
class PyPDLPatternModule {
public:
  PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
      : module(other.module) {
    other.module.ptr = nullptr;
  }
  ~PyPDLPatternModule() {
    if (module.ptr != nullptr)
      mlirPDLPatternModuleDestroy(module);
  }

  MlirPDLPatternModule get() { return module; }

  void registerRewriteFunction(const std::string &name,
                               const nb::callable &fn) {
    mlirPDLPatternModuleRegisterRewriteFunction(
        get(), mlirStringRefCreate(name.data(), name.size()),
        [](MlirPatternRewriter rewriter, MlirPDLResultList results,
           size_t nValues, MlirPDLValue *values,
           void *userData) -> MlirLogicalResult {
          nb::handle f = nb::handle(static_cast<PyObject *>(userData));
          return logicalResultFromObject(
              f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
                objectsFromPDLValues(nValues, values)));
        },
        fn.ptr());
  }

  void registerConstraintFunction(const std::string &name,
                                  const nb::callable &fn) {
    mlirPDLPatternModuleRegisterConstraintFunction(
        get(), mlirStringRefCreate(name.data(), name.size()),
        [](MlirPatternRewriter rewriter, MlirPDLResultList results,
           size_t nValues, MlirPDLValue *values,
           void *userData) -> MlirLogicalResult {
          nb::handle f = nb::handle(static_cast<PyObject *>(userData));
          return logicalResultFromObject(
              f(PyPatternRewriter(rewriter), PyMlirPDLResultList{results.ptr},
                objectsFromPDLValues(nValues, values)));
        },
        fn.ptr());
  }

private:
  MlirPDLPatternModule module;
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

/// Owning Wrapper around a FrozenRewritePatternSet.
class PyFrozenRewritePatternSet {
public:
  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
      : set(other.set) {
    other.set.ptr = nullptr;
  }
  ~PyFrozenRewritePatternSet() {
    if (set.ptr != nullptr)
      mlirFrozenRewritePatternSetDestroy(set);
  }

  MlirFrozenRewritePatternSet get() { return set; }

  nb::object getCapsule() {
    return nb::steal<nb::object>(
        mlirPythonFrozenRewritePatternSetToCapsule(get()));
  }

  static nb::object createFromCapsule(const nb::object &capsule) {
    MlirFrozenRewritePatternSet rawPm =
        mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
    if (rawPm.ptr == nullptr)
      throw nb::python_error();
    return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
  }

private:
  MlirFrozenRewritePatternSet set;
};

class PyRewritePatternSet {
public:
  PyRewritePatternSet(MlirContext ctx)
      : set(mlirRewritePatternSetCreate(ctx)), ctx(ctx) {}
  ~PyRewritePatternSet() {
    if (set.ptr)
      mlirRewritePatternSetDestroy(set);
  }

  void add(MlirStringRef rootName, unsigned benefit,
           const nb::callable &matchAndRewrite) {
    MlirRewritePatternCallbacks callbacks;
    callbacks.construct = [](void *userData) {
      nb::handle(static_cast<PyObject *>(userData)).inc_ref();
    };
    callbacks.destruct = [](void *userData) {
      nb::handle(static_cast<PyObject *>(userData)).dec_ref();
    };
    callbacks.matchAndRewrite = [](MlirRewritePattern, MlirOperation op,
                                   MlirPatternRewriter rewriter,
                                   void *userData) -> MlirLogicalResult {
      nb::handle f(static_cast<PyObject *>(userData));

      PyMlirContextRef ctx =
          PyMlirContext::forContext(mlirOperationGetContext(op));
      nb::object opView = PyOperation::forOperation(ctx, op)->createOpView();

      nb::object res = f(opView, PyPatternRewriter(rewriter));
      return logicalResultFromObject(res);
    };
    MlirRewritePattern pattern = mlirOpRewritePatternCreate(
        rootName, benefit, ctx, callbacks, matchAndRewrite.ptr(),
        /* nGeneratedNames */ 0,
        /* generatedNames */ nullptr);
    mlirRewritePatternSetAdd(set, pattern);
  }

  PyFrozenRewritePatternSet freeze() {
    MlirRewritePatternSet s = set;
    set.ptr = nullptr;
    return mlirFreezeRewritePattern(s);
  }

private:
  MlirRewritePatternSet set;
  MlirContext ctx;
};

enum PyGreedyRewriteStrictness : std::underlying_type_t<
    MlirGreedyRewriteStrictness> {
  MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS =
      MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS =
      MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS,
};

enum PyGreedySimplifyRegionLevel : std::underlying_type_t<
    MlirGreedySimplifyRegionLevel> {
  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED =
      MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL =
      MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE =
      MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
};

/// Owning Wrapper around a GreedyRewriteDriverConfig.
class PyGreedyRewriteDriverConfig {
public:
  PyGreedyRewriteDriverConfig()
      : config(mlirGreedyRewriteDriverConfigCreate()) {}
  PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
      : config(other.config) {
    other.config.ptr = nullptr;
  }
  ~PyGreedyRewriteDriverConfig() {
    if (config.ptr != nullptr)
      mlirGreedyRewriteDriverConfigDestroy(config);
  }

  MlirGreedyRewriteDriverConfig get() { return config; }

  void setMaxIterations(int64_t maxIterations) {
    mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
  }

  void setMaxNumRewrites(int64_t maxNumRewrites) {
    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
  }

  void setUseTopDownTraversal(bool useTopDownTraversal) {
    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config,
                                                        useTopDownTraversal);
  }

  void enableFolding(bool enable) {
    mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
  }

  void setStrictness(PyGreedyRewriteStrictness strictness) {
    mlirGreedyRewriteDriverConfigSetStrictness(
        config, static_cast<MlirGreedyRewriteStrictness>(strictness));
  }

  void setRegionSimplificationLevel(PyGreedySimplifyRegionLevel level) {
    mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
        config, static_cast<MlirGreedySimplifyRegionLevel>(level));
  }

  void enableConstantCSE(bool enable) {
    mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
  }

  int64_t getMaxIterations() {
    return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
  }

  int64_t getMaxNumRewrites() {
    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
  }

  bool getUseTopDownTraversal() {
    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
  }

  bool isFoldingEnabled() {
    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
  }

  PyGreedyRewriteStrictness getStrictness() {
    return static_cast<PyGreedyRewriteStrictness>(
        mlirGreedyRewriteDriverConfigGetStrictness(config));
  }

  PyGreedySimplifyRegionLevel getRegionSimplificationLevel() {
    return static_cast<PyGreedySimplifyRegionLevel>(
        mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
  }

  bool isConstantCSEEnabled() {
    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
  }

private:
  MlirGreedyRewriteDriverConfig config;
};

/// Create the `mlir.rewrite` submodule here.
void populateRewriteSubmodule(nb::module_ &m) {
  // Enum definitions
  nb::enum_<PyGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
      .value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP)
      .value("EXISTING_AND_NEW_OPS",
             MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS)
      .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);

  nb::enum_<PyGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
      .value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED)
      .value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL)
      .value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE);

  //----------------------------------------------------------------------------
  // Mapping of the PatternRewriter
  //----------------------------------------------------------------------------
  nb::class_<PyPatternRewriter>(m, "PatternRewriter")
      .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint,
                   "The current insertion point of the PatternRewriter.")
      .def(
          "replace_op",
          [](PyPatternRewriter &self, MlirOperation op,
             MlirOperation newOp) { self.replaceOp(op, newOp); },
          "Replace an operation with a new operation.", nb::arg("op"),
          nb::arg("new_op"),
          // clang-format off
          nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", new_op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
          // clang-format on
          )
      .def(
          "replace_op",
          [](PyPatternRewriter &self, MlirOperation op,
             const std::vector<MlirValue> &values) {
            self.replaceOp(op, values);
          },
          "Replace an operation with a list of values.", nb::arg("op"),
          nb::arg("values"),
          // clang-format off
          nb::sig("def replace_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", values: list[" MAKE_MLIR_PYTHON_QUALNAME("ir.Value") "]) -> None")
          // clang-format on
          )
      .def("erase_op", &PyPatternRewriter::eraseOp, "Erase an operation.",
           nb::arg("op"),
           // clang-format off
           nb::sig("def erase_op(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ") -> None")
           // clang-format on
      );

  //----------------------------------------------------------------------------
  // Mapping of the RewritePatternSet
  //----------------------------------------------------------------------------
  nb::class_<PyRewritePatternSet>(m, "RewritePatternSet")
      .def(
          "__init__",
          [](PyRewritePatternSet &self, DefaultingPyMlirContext context) {
            new (&self) PyRewritePatternSet(context.get()->get());
          },
          "context"_a = nb::none())
      .def(
          "add",
          [](PyRewritePatternSet &self, nb::handle root, const nb::callable &fn,
             unsigned benefit) {
            std::string opName;
            if (root.is_type()) {
              opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
            } else if (nb::isinstance<nb::str>(root)) {
              opName = nb::cast<std::string>(root);
            } else {
              throw nb::type_error(
                  "the root argument must be a type or a string");
            }
            self.add(mlirStringRefCreate(opName.data(), opName.size()), benefit,
                     fn);
          },
          "root"_a, "fn"_a, "benefit"_a = 1,
          // clang-format off
          nb::sig("def add(self, root: type | str, fn: typing.Callable[[" MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ", PatternRewriter], typing.Any], benefit: int = 1) -> None"),
          // clang-format on
          R"(
Add a new rewrite pattern on the specified root operation, using the provided callable
for matching and rewriting, and assign it the given benefit.

Args:
  root: The root operation to which this pattern applies.
    This may be either an OpView subclass (e.g., ``arith.AddIOp``) or
    an operation name string (e.g., ``"arith.addi"``).
  fn: The callable to use for matching and rewriting,
    which takes an operation and a pattern rewriter as arguments.
    The match is considered successful iff the callable returns
    a value where ``bool(value)`` is ``False`` (e.g. ``None``).
    If possible, the operation is cast to its corresponding OpView subclass
    before being passed to the callable.
  benefit: The benefit of the pattern, defaulting to 1.)")
      .def("freeze", &PyRewritePatternSet::freeze,
           "Freeze the pattern set into a frozen one.");

  //----------------------------------------------------------------------------
  // Mapping of the PDLResultList and PDLModule
  //----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
  nb::class_<PyMlirPDLResultList>(m, "PDLResultList")
      .def(
          "append",
          [](PyMlirPDLResultList results, const PyValue &value) {
            mlirPDLResultListPushBackValue(results, value);
          },
          // clang-format off
          nb::sig("def append(self, value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")")
          // clang-format on
          )
      .def(
          "append",
          [](PyMlirPDLResultList results, const PyOperation &op) {
            mlirPDLResultListPushBackOperation(results, op);
          },
          // clang-format off
          nb::sig("def append(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")")
          // clang-format on
          )
      .def(
          "append",
          [](PyMlirPDLResultList results, const PyType &type) {
            mlirPDLResultListPushBackType(results, type);
          },
          // clang-format off
          nb::sig("def append(self, type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")")
          // clang-format on
          )
      .def(
          "append",
          [](PyMlirPDLResultList results, const PyAttribute &attr) {
            mlirPDLResultListPushBackAttribute(results, attr);
          },
          // clang-format off
          nb::sig("def append(self, attr: " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")")
          // clang-format on
      );

  nb::class_<PyPDLPatternModule>(m, "PDLModule")
      .def(
          "__init__",
          [](PyPDLPatternModule &self, PyModule &module) {
            new (&self) PyPDLPatternModule(
                mlirPDLPatternModuleFromModule(module.get()));
          },
          // clang-format off
          nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
          // clang-format on
          "module"_a, "Create a PDL module from the given module.")
      .def(
          "freeze",
          [](PyPDLPatternModule &self) {
            return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
                mlirRewritePatternSetFromPDLPatternModule(self.get())));
          },
          nb::keep_alive<0, 1>())
      .def(
          "register_rewrite_function",
          [](PyPDLPatternModule &self, const std::string &name,
             const nb::callable &fn) {
            self.registerRewriteFunction(name, fn);
          },
          nb::keep_alive<1, 3>())
      .def(
          "register_constraint_function",
          [](PyPDLPatternModule &self, const std::string &name,
             const nb::callable &fn) {
            self.registerConstraintFunction(name, fn);
          },
          nb::keep_alive<1, 3>());
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

  nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
      .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
      .def_prop_rw("max_iterations",
                   &PyGreedyRewriteDriverConfig::getMaxIterations,
                   &PyGreedyRewriteDriverConfig::setMaxIterations,
                   "Maximum number of iterations")
      .def_prop_rw("max_num_rewrites",
                   &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
                   &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
                   "Maximum number of rewrites per iteration")
      .def_prop_rw("use_top_down_traversal",
                   &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
                   &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
                   "Whether to use top-down traversal")
      .def_prop_rw("enable_folding",
                   &PyGreedyRewriteDriverConfig::isFoldingEnabled,
                   &PyGreedyRewriteDriverConfig::enableFolding,
                   "Enable or disable folding")
      .def_prop_rw("strictness", &PyGreedyRewriteDriverConfig::getStrictness,
                   &PyGreedyRewriteDriverConfig::setStrictness,
                   "Rewrite strictness level")
      .def_prop_rw("region_simplification_level",
                   &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
                   &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
                   "Region simplification level")
      .def_prop_rw("enable_constant_cse",
                   &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
                   &PyGreedyRewriteDriverConfig::enableConstantCSE,
                   "Enable or disable constant CSE");

  nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
      .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
                   &PyFrozenRewritePatternSet::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
           &PyFrozenRewritePatternSet::createFromCapsule);

  m.def(
       "apply_patterns_and_fold_greedily",
       [](PyModule &module, PyFrozenRewritePatternSet &set) {
         // Owning wrapper: the default config created for this call is
         // destroyed on return instead of being leaked.
         PyGreedyRewriteDriverConfig config;
         auto status = mlirApplyPatternsAndFoldGreedily(module.get(), set.get(),
                                                        config.get());
         if (mlirLogicalResultIsFailure(status))
           throw std::runtime_error("pattern application failed to converge");
       },
       "module"_a, "set"_a,
       // clang-format off
       nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
       // clang-format on
       "Applies the given patterns to the given module greedily while folding "
       "results.")
      .def(
          "apply_patterns_and_fold_greedily",
          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
            // Owning wrapper: the default config created for this call is
            // destroyed on return instead of being leaked.
            PyGreedyRewriteDriverConfig config;
            auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
                op.getOperation(), set.get(), config.get());
            if (mlirLogicalResultIsFailure(status))
              throw std::runtime_error(
                  "pattern application failed to converge");
          },
          "op"_a, "set"_a,
          // clang-format off
          nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
          // clang-format on
          "Applies the given patterns to the given op greedily while folding "
          "results.")
      .def(
          "walk_and_apply_patterns",
          [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
            mlirWalkAndApplyPatterns(op.getOperation(), set.get());
          },
          "op"_a, "set"_a,
          // clang-format off
          nb::sig("def walk_and_apply_patterns(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
          // clang-format on
          "Applies the given patterns to the given op by a fast walk-based "
          "driver.");
}
} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
} // namespace python
} // namespace mlir