Maksim Levental efd96afedf
[MLIR][Python] reland (narrower) type stub generation (#157930)
This is a reland of https://github.com/llvm/llvm-project/pull/155741, which
was reverted in https://github.com/llvm/llvm-project/pull/157831. This
version is narrower in scope: it only turns on automatic stub
generation for `MLIRPythonExtension.Core._mlir` and **does not do
anything else automatically**. Specifically, the only CMake code added to
`AddMLIRPython.cmake` is the `mlir_generate_type_stubs` function, which
is then invoked only manually. The API for
`mlir_generate_type_stubs` is:

```
Arguments:
  MODULE_NAME: The fully-qualified name of the extension module (used for importing in python).
  DEPENDS_TARGETS: List of targets these type stubs depend on being built; usually corresponding to the
    specific extension module (e.g., something like StandalonePythonModules.extension._standaloneDialectsNanobind.dso)
    and the core bindings extension module (e.g., something like StandalonePythonModules.extension._mlir.dso).
  OUTPUT_DIR: The root output directory to emit the type stubs into.
  OUTPUTS: List of expected outputs.
  DEPENDS_TARGET_SRC_DEPS: List of C++ sources for the extension library (used to generate a DEPFILE).
  IMPORT_PATHS: List of paths to add to PYTHONPATH for stubgen.
  PATTERN_FILE: (Optional) Pattern file (see https://nanobind.readthedocs.io/en/latest/typing.html#pattern-files).
Outputs:
  NB_STUBGEN_CUSTOM_TARGET: The target corresponding to generation which other targets can depend on.
```

Downstream users should use `mlir_generate_type_stubs` in coordination
with `declare_mlir_python_sources` to turn on stub generation for their
own downstream dialect extensions and, if they so choose, for upstream
dialect extensions. The Standalone example shows how to do this.

Note: downstream users will also need to set
`-DMLIR_PYTHON_PACKAGE_PREFIX=...` correctly for their bindings.
2025-09-20 18:47:32 +00:00

136 lines
4.9 KiB
C++

//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Rewrite.h"
#include "IRModule.h"
#include "mlir-c/Rewrite.h"
// clang-format off
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
// clang-format on
#include "mlir/Config/mlir-config.h"
namespace nb = nanobind;
using namespace mlir;
using namespace nb::literals;
using namespace mlir::python;
namespace {
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// Owning, move-only wrapper around an MlirPDLPatternModule.
///
/// The wrapped handle is released with mlirPDLPatternModuleDestroy when the
/// wrapper is destroyed. A moved-from wrapper holds a null handle and releases
/// nothing, so exactly one wrapper ever destroys a given module.
class PyPDLPatternModule {
public:
  /// Takes ownership of `module`. Explicit so that a bare C handle is never
  /// silently turned into an owning wrapper.
  explicit PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
  PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
      : module(other.module) {
    // Leave the source empty so its destructor does not destroy the module.
    other.module.ptr = nullptr;
  }
  // Copying would destroy the same handle twice; assignment is not needed.
  PyPDLPatternModule(const PyPDLPatternModule &) = delete;
  PyPDLPatternModule &operator=(const PyPDLPatternModule &) = delete;
  PyPDLPatternModule &operator=(PyPDLPatternModule &&) = delete;
  ~PyPDLPatternModule() {
    if (module.ptr != nullptr)
      mlirPDLPatternModuleDestroy(module);
  }

  /// Returns the wrapped handle; ownership stays with this wrapper.
  MlirPDLPatternModule get() const { return module; }

private:
  MlirPDLPatternModule module;
};
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

/// Owning, move-only wrapper around an MlirFrozenRewritePatternSet.
///
/// The wrapped handle is released with mlirFrozenRewritePatternSetDestroy when
/// the wrapper is destroyed. A moved-from wrapper holds a null handle and
/// releases nothing.
class PyFrozenRewritePatternSet {
public:
  /// Takes ownership of `set`. Explicit so that a bare C handle is never
  /// silently turned into an owning wrapper.
  explicit PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set)
      : set(set) {}
  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
      : set(other.set) {
    // Leave the source empty so its destructor does not destroy the set.
    other.set.ptr = nullptr;
  }
  // Copying would destroy the same handle twice; assignment is not needed.
  PyFrozenRewritePatternSet(const PyFrozenRewritePatternSet &) = delete;
  PyFrozenRewritePatternSet &
  operator=(const PyFrozenRewritePatternSet &) = delete;
  PyFrozenRewritePatternSet &operator=(PyFrozenRewritePatternSet &&) = delete;
  ~PyFrozenRewritePatternSet() {
    if (set.ptr != nullptr)
      mlirFrozenRewritePatternSetDestroy(set);
  }

  /// Returns the wrapped handle; ownership stays with this wrapper.
  MlirFrozenRewritePatternSet get() const { return set; }

  /// Returns a Python capsule wrapping the raw handle. The new reference
  /// produced by mlirPythonFrozenRewritePatternSetToCapsule is adopted via
  /// nb::steal. NOTE(review): the capsule does not appear to own the handle;
  /// this wrapper must outlive any use of it — confirm against Interop.h.
  nb::object getCapsule() {
    return nb::steal<nb::object>(
        mlirPythonFrozenRewritePatternSetToCapsule(get()));
  }

  /// Builds a new wrapper from a capsule and moves it into a Python object.
  /// The new wrapper takes ownership of the extracted handle and will destroy
  /// it. Throws nb::python_error if the extracted handle is null (presumably
  /// the extraction helper has set a Python error in that case — confirm).
  static nb::object createFromCapsule(nb::object capsule) {
    MlirFrozenRewritePatternSet rawPm =
        mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
    if (rawPm.ptr == nullptr)
      throw nb::python_error();
    return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move);
  }

private:
  MlirFrozenRewritePatternSet set;
};
} // namespace
/// Populate the `mlir.rewrite` Python submodule `m`.
///
/// Registers `PDLModule` (only when MLIR_ENABLE_PDL_IN_PATTERNMATCH is
/// enabled), `FrozenRewritePatternSet`, and two overloads of
/// `apply_patterns_and_fold_greedily`:
/// - one taking an `ir.Module`;
/// - one taking an `ir._OperationBase`.
///
/// Both overloads run the greedy driver with a default (`{}`) configuration.
/// They throw std::runtime_error when the driver reports failure.
void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
  //----------------------------------------------------------------------------
  // Mapping of PDLModule (only available with PDL support in pattern match)
  //----------------------------------------------------------------------------
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
  nb::class_<PyPDLPatternModule>(m, "PDLModule")
      .def(
          "__init__",
          [](PyPDLPatternModule &self, PyModule &module) {
            // nanobind hands __init__ uninitialized storage for `self`, so the
            // wrapper is constructed in place.
            new (&self) PyPDLPatternModule(
                mlirPDLPatternModuleFromModule(module.get()));
          },
          // clang-format off
          nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"),
          // clang-format on
          "module"_a, "Create a PDL module from the given module.")
      .def("freeze", [](PyPDLPatternModule &self) {
        // Return by value: nanobind moves the wrapper into a new Python
        // object that then owns the frozen set (no naked `new` needed).
        return PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
            mlirRewritePatternSetFromPDLPatternModule(self.get())));
      });
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

  //----------------------------------------------------------------------------
  // Mapping of FrozenRewritePatternSet
  //----------------------------------------------------------------------------
  nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
      .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
                   &PyFrozenRewritePatternSet::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
           &PyFrozenRewritePatternSet::createFromCapsule);

  //----------------------------------------------------------------------------
  // Greedy pattern application (overloaded on module vs. operation)
  //----------------------------------------------------------------------------
  m.def(
      "apply_patterns_and_fold_greedily",
      [](PyModule &module, PyFrozenRewritePatternSet &set) {
        auto status =
            mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {});
        if (mlirLogicalResultIsFailure(status))
          throw std::runtime_error("pattern application failed to converge");
      },
      "module"_a, "set"_a,
      // clang-format off
      nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"),
      // clang-format on
      "Applies the given patterns to the given module greedily while folding "
      "results.");
  m.def(
      "apply_patterns_and_fold_greedily",
      [](PyOperationBase &op, PyFrozenRewritePatternSet &set) {
        auto status = mlirApplyPatternsAndFoldGreedilyWithOp(
            op.getOperation(), set.get(), {});
        if (mlirLogicalResultIsFailure(status))
          throw std::runtime_error("pattern application failed to converge");
      },
      "op"_a, "set"_a,
      // clang-format off
      nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"),
      // clang-format on
      "Applies the given patterns to the given op greedily while folding "
      "results.");
}