Peter Hawkins 41bd35b58b
[mlir python] Port Python core code to nanobind. (#118583)
Why? https://nanobind.readthedocs.io/en/latest/why.html says it better
than I can, but my primary motivation for this change is to improve MLIR
IR construction time from JAX.

For a complicated Google-internal LLM model in JAX, this change improves
the MLIR
lowering time by around 5s (out of around 30s), which is a significant
speedup for simply switching binding frameworks.

To a large extent, this is a mechanical change, for instance changing
`pybind11::`
to `nanobind::`.

Notes:
* this PR needs Nanobind 2.4.0, because it needs a bug fix
(https://github.com/wjakob/nanobind/pull/806) that landed in that
release.
* this PR does not port the in-tree dialect extension modules. They can
be ported in a future PR.
* I removed the py::sibling() annotations from def_static and def_class
in `PybindAdapters.h`. These ask pybind11 to try to form an overload
with an existing method, but it's not possible to form mixed
pybind11/nanobind overloads this way and the parent class is now
defined in nanobind. Better solutions may be possible here.
* nanobind does not contain an exact equivalent of pybind11's buffer
protocol support. It was not hard to add a nanobind implementation of a
similar API.
* nanobind is pickier about casting to std::vector<bool>, expecting that
the input is a sequence of bool types, not truthy values. In a couple of
places I added code to support truthy values during casting.
* nanobind distinguishes bytes (`nb::bytes`) from strings (e.g.,
`std::string`). This required nb::bytes overloads in a few places.
2024-12-18 11:16:11 -08:00

180 lines
7.7 KiB
C++

//===- Pass.cpp - Pass Management -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Pass.h"
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <utility>
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"
namespace nb = nanobind;
using namespace nb::literals;
using namespace mlir;
using namespace mlir::python;
namespace {
/// Owning wrapper around an `MlirPassManager` handle.
///
/// The wrapped pass manager is destroyed together with the wrapper unless
/// ownership was given up through `release()` or transferred to another wrapper
/// by move construction. The type is move-constructible only: copying would
/// lead to the same handle being destroyed twice.
class PyPassManager {
public:
  /// Takes ownership of `passManager`.
  ///
  /// Explicit so that a raw handle is never silently converted into an owning
  /// temporary, which would destroy the handle when the temporary dies.
  explicit PyPassManager(MlirPassManager passManager)
      : passManager(passManager) {}

  /// Transfers ownership from `other`, leaving `other` holding a null handle
  /// so that its destructor becomes a no-op.
  PyPassManager(PyPassManager &&other) noexcept
      : passManager(other.passManager) {
    other.passManager.ptr = nullptr;
  }

  // Single-owner semantics: no copies and no assignment.
  PyPassManager(const PyPassManager &) = delete;
  PyPassManager &operator=(const PyPassManager &) = delete;
  PyPassManager &operator=(PyPassManager &&) = delete;

  /// Destroys the wrapped pass manager unless the handle is null (i.e. it was
  /// released or moved from).
  ~PyPassManager() {
    if (!mlirPassManagerIsNull(passManager))
      mlirPassManagerDestroy(passManager);
  }

  /// Returns the wrapped handle without giving up ownership.
  MlirPassManager get() const { return passManager; }

  /// Drops ownership without destroying the pass manager. The handle is
  /// intentionally leaked; this is exposed to Python for testing only.
  void release() { passManager.ptr = nullptr; }

  /// Returns a new Python capsule object referring to the wrapped handle.
  nb::object getCapsule() const {
    return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get()));
  }

  /// Creates a Python-level `PassManager` that owns the handle stored in
  /// `capsule`.
  ///
  /// Throws `nb::python_error` when the capsule does not yield a pass manager.
  /// NOTE(review): this presumably relies on the capsule conversion having
  /// already set the Python error indicator - confirm against Interop.h.
  static nb::object createFromCapsule(nb::object capsule) {
    MlirPassManager unwrapped = mlirPythonCapsuleToPassManager(capsule.ptr());
    if (mlirPassManagerIsNull(unwrapped))
      throw nb::python_error();
    return nb::cast(PyPassManager(unwrapped), nb::rv_policy::move);
  }

private:
  /// Owned handle; a null `ptr` means "owns nothing".
  MlirPassManager passManager;
};
} // namespace
/// Create the `mlir.passmanager` here.
///
/// Registers the Python `PassManager` class on module `m`, backed by
/// `PyPassManager`. The class exposes construction (`__init__`, `parse`, and
/// the C-API capsule factory), pipeline configuration (`add`,
/// `enable_ir_printing`, `enable_verifier`), execution (`run`) and textual
/// printing (`__str__`).
void mlir::python::populatePassManagerSubmodule(nb::module_ &m) {
  //----------------------------------------------------------------------------
  // Mapping of the top-level PassManager
  //----------------------------------------------------------------------------
  nb::class_<PyPassManager>(m, "PassManager")
      .def(
          "__init__",
          [](PyPassManager &self, const std::string &anchorOp,
             DefaultingPyMlirContext context) {
            MlirPassManager passManager = mlirPassManagerCreateOnOperation(
                context->get(),
                mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
            // `self` is storage provided by the binding layer for `__init__`;
            // construct the wrapper in place so it takes ownership.
            new (&self) PyPassManager(passManager);
          },
          "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(),
          "Create a new PassManager for the current (or provided) Context.")
      .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
      .def("_testing_release", &PyPassManager::release,
           "Releases (leaks) the backing pass manager (testing)")
      .def(
          "enable_ir_printing",
          [](PyPassManager &passManager, bool printBeforeAll,
             bool printAfterAll, bool printModuleScope, bool printAfterChange,
             bool printAfterFailure, std::optional<int64_t> largeElementsLimit,
             bool enableDebugInfo, bool printGenericOpForm,
             std::optional<std::string> optionalTreePrintingPath) {
            // The flags object only lives for the duration of this call; it is
            // destroyed again once it has been handed to the pass manager.
            MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
            if (largeElementsLimit)
              mlirOpPrintingFlagsElideLargeElementsAttrs(flags,
                                                         *largeElementsLimit);
            if (enableDebugInfo)
              mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true,
                                                 /*prettyForm=*/false);
            if (printGenericOpForm)
              mlirOpPrintingFlagsPrintGenericOpForm(flags);
            // A missing path is forwarded as the empty string.
            const std::string treePrintingPath =
                optionalTreePrintingPath.value_or("");
            mlirPassManagerEnableIRPrinting(
                passManager.get(), printBeforeAll, printAfterAll,
                printModuleScope, printAfterChange, printAfterFailure, flags,
                mlirStringRefCreate(treePrintingPath.data(),
                                    treePrintingPath.size()));
            mlirOpPrintingFlagsDestroy(flags);
          },
          "print_before_all"_a = false, "print_after_all"_a = true,
          "print_module_scope"_a = false, "print_after_change"_a = false,
          "print_after_failure"_a = false,
          "large_elements_limit"_a.none() = nb::none(),
          "enable_debug_info"_a = false, "print_generic_op_form"_a = false,
          "tree_printing_dir_path"_a.none() = nb::none(),
          "Enable IR printing, default as mlir-print-ir-after-all.")
      .def(
          "enable_verifier",
          [](PyPassManager &passManager, bool enable) {
            mlirPassManagerEnableVerifier(passManager.get(), enable);
          },
          "enable"_a, "Enable / disable verify-each.")
      .def_static(
          "parse",
          [](const std::string &pipeline, DefaultingPyMlirContext context) {
            // Take ownership right away: if parsing fails and we throw below,
            // the freshly created pass manager is destroyed instead of leaked.
            PyPassManager owned(mlirPassManagerCreate(context->get()));
            PyPrintAccumulator errorMsg;
            MlirLogicalResult status = mlirParsePassPipeline(
                mlirPassManagerGetAsOpPassManager(owned.get()),
                mlirStringRefCreate(pipeline.data(), pipeline.size()),
                errorMsg.getCallback(), errorMsg.getUserData());
            if (mlirLogicalResultIsFailure(status))
              throw nb::value_error(errorMsg.join().c_str());
            // Hand ownership to a heap object returned to the binding layer,
            // exactly as before (a raw `PyPassManager *` is returned).
            return new PyPassManager(std::move(owned));
          },
          "pipeline"_a, "context"_a.none() = nb::none(),
          "Parse a textual pass-pipeline and return a top-level PassManager "
          "that can be applied on a Module. Throw a ValueError if the pipeline "
          "can't be parsed")
      .def(
          "add",
          [](PyPassManager &passManager, const std::string &pipeline) {
            PyPrintAccumulator errorMsg;
            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
                mlirPassManagerGetAsOpPassManager(passManager.get()),
                mlirStringRefCreate(pipeline.data(), pipeline.size()),
                errorMsg.getCallback(), errorMsg.getUserData());
            if (mlirLogicalResultIsFailure(status))
              throw nb::value_error(errorMsg.join().c_str());
          },
          "pipeline"_a,
          "Add textual pipeline elements to the pass manager. Throws a "
          "ValueError if the pipeline can't be parsed.")
      .def(
          "run",
          [](PyPassManager &passManager, PyOperationBase &op,
             bool invalidateOps) {
            if (invalidateOps) {
              op.getOperation().getContext()->clearOperationsInside(op);
            }
            // Actually run the pass manager. `errors` is set up before the run
            // and its contents are attached to the MLIRError on failure.
            PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
            MlirLogicalResult status = mlirPassManagerRunOnOp(
                passManager.get(), op.getOperation().get());
            if (mlirLogicalResultIsFailure(status))
              throw MLIRError("Failure while executing pass pipeline",
                              errors.take());
          },
          "operation"_a, "invalidate_ops"_a = true,
          "Run the pass manager on the provided operation, raising an "
          "MLIRError on failure.")
      .def(
          "__str__",
          [](PyPassManager &self) {
            MlirPassManager passManager = self.get();
            PyPrintAccumulator printAccum;
            mlirPrintPassPipeline(
                mlirPassManagerGetAsOpPassManager(passManager),
                printAccum.getCallback(), printAccum.getUserData());
            return printAccum.join();
          },
          "Print the textual representation for this PassManager, suitable to "
          "be passed to `parse` for round-tripping.");
}