
Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yes plumb through adding support for custom matchers through this interface, so constrained to basics initially. This also exposes greedy rewrite driver. Only way currently to define patterns is via PDL (just to keep small). The creation of the PDL pattern module could be improved to avoid folks potentially accessing the module used to construct it post construction. No ergonomic work done yet. --------- Signed-off-by: Jacques Pienaar <jpienaar@google.com>
111 lines
3.8 KiB
C++
111 lines
3.8 KiB
C++
//===- Rewrite.cpp - Rewrite ----------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "Rewrite.h"
|
|
|
|
#include "IRModule.h"
|
|
#include "mlir-c/Bindings/Python/Interop.h"
|
|
#include "mlir-c/Rewrite.h"
|
|
#include "mlir/Config/mlir-config.h"
|
|
|
|
namespace py = pybind11;
|
|
using namespace mlir;
|
|
using namespace py::literals;
|
|
using namespace mlir::python;
|
|
|
|
namespace {
|
|
|
|
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
/// Owning Wrapper around a PDLPatternModule.
|
|
class PyPDLPatternModule {
|
|
public:
|
|
PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {}
|
|
PyPDLPatternModule(PyPDLPatternModule &&other) noexcept
|
|
: module(other.module) {
|
|
other.module.ptr = nullptr;
|
|
}
|
|
~PyPDLPatternModule() {
|
|
if (module.ptr != nullptr)
|
|
mlirPDLPatternModuleDestroy(module);
|
|
}
|
|
MlirPDLPatternModule get() { return module; }
|
|
|
|
private:
|
|
MlirPDLPatternModule module;
|
|
};
|
|
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
|
|
/// Owning Wrapper around a FrozenRewritePatternSet.
|
|
class PyFrozenRewritePatternSet {
|
|
public:
|
|
PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
|
|
PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
|
|
: set(other.set) {
|
|
other.set.ptr = nullptr;
|
|
}
|
|
~PyFrozenRewritePatternSet() {
|
|
if (set.ptr != nullptr)
|
|
mlirFrozenRewritePatternSetDestroy(set);
|
|
}
|
|
MlirFrozenRewritePatternSet get() { return set; }
|
|
|
|
pybind11::object getCapsule() {
|
|
return py::reinterpret_steal<py::object>(
|
|
mlirPythonFrozenRewritePatternSetToCapsule(get()));
|
|
}
|
|
|
|
static pybind11::object createFromCapsule(pybind11::object capsule) {
|
|
MlirFrozenRewritePatternSet rawPm =
|
|
mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
|
|
if (rawPm.ptr == nullptr)
|
|
throw py::error_already_set();
|
|
return py::cast(PyFrozenRewritePatternSet(rawPm),
|
|
py::return_value_policy::move);
|
|
}
|
|
|
|
private:
|
|
MlirFrozenRewritePatternSet set;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Create the `mlir.rewrite` here.
|
|
void mlir::python::populateRewriteSubmodule(py::module &m) {
|
|
//----------------------------------------------------------------------------
|
|
// Mapping of the top-level PassManager
|
|
//----------------------------------------------------------------------------
|
|
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
|
|
py::class_<PyPDLPatternModule>(m, "PDLModule", py::module_local())
|
|
.def(py::init<>([](MlirModule module) {
|
|
return mlirPDLPatternModuleFromModule(module);
|
|
}),
|
|
"module"_a, "Create a PDL module from the given module.")
|
|
.def("freeze", [](PyPDLPatternModule &self) {
|
|
return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
|
|
mlirRewritePatternSetFromPDLPatternModule(self.get())));
|
|
});
|
|
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg
|
|
py::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet",
|
|
py::module_local())
|
|
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
|
|
&PyFrozenRewritePatternSet::getCapsule)
|
|
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
|
|
&PyFrozenRewritePatternSet::createFromCapsule);
|
|
m.def(
|
|
"apply_patterns_and_fold_greedily",
|
|
[](MlirModule module, MlirFrozenRewritePatternSet set) {
|
|
auto status = mlirApplyPatternsAndFoldGreedily(module, set, {});
|
|
if (mlirLogicalResultIsFailure(status))
|
|
// FIXME: Not sure this is the right error to throw here.
|
|
throw py::value_error("pattern application failed to converge");
|
|
},
|
|
"module"_a, "set"_a,
|
|
"Applys the given patterns to the given module greedily while folding "
|
|
"results.");
|
|
}
|