Rahul Kayaith 3ea4c5014d [mlir][python] Capture error diagnostics in exceptions
This updates most (all?) error-diagnostic-emitting python APIs to
capture error diagnostics and include them in the raised exception's
message:
```
>>> Operation.parse('"arith.addi"() : () -> ()')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
mlir._mlir_libs.MLIRError: Unable to parse operation assembly:
error: "-":1:1: 'arith.addi' op requires one result
 note: "-":1:1: see current operation: "arith.addi"() : () -> ()
```

The diagnostic information is available on the exception for users who
may want to customize the error message:
```
>>> try:
...   Operation.parse('"arith.addi"() : () -> ()')
... except MLIRError as e:
...   print(e.message)
...   print(e.error_diagnostics)
...   print(e.error_diagnostics[0].message)
...
Unable to parse operation assembly
[<mlir._mlir_libs._mlir.ir.DiagnosticInfo object at 0x7fed32bd6b70>]
'arith.addi' op requires one result
```

Error diagnostics captured in exceptions aren't propagated to diagnostic
handlers, to avoid double-reporting of errors. The context-level
`emit_error_diagnostics` option can be used to revert to the old
behaviour, causing error diagnostics to be reported to handlers instead
of as part of exceptions.

API changes:
- `Operation.verify` now raises an exception on verification failure,
  instead of returning `false`
- The exception raised by the following methods has been changed to
  `MLIRError`:
  - `PassManager.run`
  - `{Module,Operation,Type,Attribute}.parse`
  - `{RankedTensorType,UnrankedTensorType}.get`
  - `{MemRefType,UnrankedMemRefType}.get`
  - `VectorType.get`
  - `FloatAttr.get`

closes #60595

depends on D144804, D143830

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D143869
2023-03-07 14:59:22 -05:00

143 lines
5.9 KiB
C++

//===- Pass.cpp - Pass Management -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Pass.h"
#include "IRModule.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Pass.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
namespace {
/// Owning wrapper around an MlirPassManager.
///
/// The destructor destroys the wrapped pass manager unless ownership was given
/// up beforehand, either by moving out of this object or by calling release().
/// The type is move-constructible only; copying and assignment are disabled so
/// that two wrappers can never destroy the same pass manager.
class PyPassManager {
public:
  /// Wraps `passManager`; it is destroyed when this object is destroyed.
  explicit PyPassManager(MlirPassManager passManager)
      : passManager(passManager) {}

  /// Steals the handle from `other`, leaving `other` null so that its
  /// destructor becomes a no-op.
  PyPassManager(PyPassManager &&other) noexcept
      : passManager(other.passManager) {
    other.passManager.ptr = nullptr;
  }

  // Copying/assigning would lead to a double destroy of the same handle.
  PyPassManager(const PyPassManager &) = delete;
  PyPassManager &operator=(const PyPassManager &) = delete;
  PyPassManager &operator=(PyPassManager &&) = delete;

  ~PyPassManager() {
    if (!mlirPassManagerIsNull(passManager))
      mlirPassManagerDestroy(passManager);
  }

  /// Returns the wrapped handle without affecting ownership.
  MlirPassManager get() const { return passManager; }

  /// Drops the handle without destroying it; the destructor will no longer
  /// free the pass manager after this call.
  void release() { passManager.ptr = nullptr; }

  /// Returns a Python capsule object built from the wrapped handle.
  pybind11::object getCapsule() {
    return py::reinterpret_steal<py::object>(
        mlirPythonPassManagerToCapsule(get()));
  }

  /// Builds an owning PyPassManager Python object from `capsule`.
  ///
  /// Throws py::error_already_set when the capsule does not yield a pass
  /// manager (presumably the capsule conversion has already set the Python
  /// error in that case -- TODO confirm against Interop.h).
  static pybind11::object createFromCapsule(pybind11::object capsule) {
    MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
    if (mlirPassManagerIsNull(rawPm))
      throw py::error_already_set();
    return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
  }

private:
  /// The owned handle; its `ptr` is null once ownership has been given up.
  MlirPassManager passManager;
};
} // namespace
/// Create the `mlir.passmanager` here.
///
/// Registers the module-local Python class `PassManager` on `m`, backed by
/// PyPassManager, together with its constructor, capsule interop hooks and the
/// `enable_ir_printing`, `enable_verifier`, `parse`, `add`, `run` and
/// `__str__` methods.
void mlir::python::populatePassManagerSubmodule(py::module &m) {
  //----------------------------------------------------------------------------
  // Mapping of the top-level PassManager
  //----------------------------------------------------------------------------
  py::class_<PyPassManager>(m, "PassManager", py::module_local())
      // Constructor: creates a pass manager anchored on the operation name
      // `anchor_op` in the given (or default) context.
      .def(py::init<>([](const std::string &anchorOp,
                         DefaultingPyMlirContext context) {
             MlirPassManager passManager = mlirPassManagerCreateOnOperation(
                 context->get(),
                 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
             return new PyPassManager(passManager);
           }),
           py::arg("anchor_op") = py::str("any"),
           py::arg("context") = py::none(),
           "Create a new PassManager for the current (or provided) Context.")
      // Capsule-based interop with other binding layers.
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             &PyPassManager::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule)
      .def("_testing_release", &PyPassManager::release,
           "Releases (leaks) the backing pass manager (testing)")
      .def(
          "enable_ir_printing",
          [](PyPassManager &passManager) {
            mlirPassManagerEnableIRPrinting(passManager.get());
          },
          "Enable mlir-print-ir-after-all.")
      .def(
          "enable_verifier",
          [](PyPassManager &passManager, bool enable) {
            mlirPassManagerEnableVerifier(passManager.get(), enable);
          },
          py::arg("enable"), "Enable / disable verify-each.")
      .def_static(
          "parse",
          [](const std::string &pipeline, DefaultingPyMlirContext context) {
            // Hold the freshly created pass manager in a scoped owner so that
            // it is destroyed if parsing fails (or anything below throws)
            // instead of being leaked.
            PyPassManager owner(mlirPassManagerCreate(context->get()));
            PyPrintAccumulator errorMsg;
            MlirLogicalResult status = mlirParsePassPipeline(
                mlirPassManagerGetAsOpPassManager(owner.get()),
                mlirStringRefCreate(pipeline.data(), pipeline.size()),
                errorMsg.getCallback(), errorMsg.getUserData());
            if (mlirLogicalResultIsFailure(status))
              throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
            // Hand the handle over to the heap wrapper returned to Python.
            // `owner` only lets go after the allocation has succeeded, so the
            // handle has exactly one owner at every point.
            PyPassManager *result = new PyPassManager(owner.get());
            owner.release();
            return result;
          },
          py::arg("pipeline"), py::arg("context") = py::none(),
          "Parse a textual pass-pipeline and return a top-level PassManager "
          "that can be applied on a Module. Throw a ValueError if the pipeline "
          "can't be parsed")
      .def(
          "add",
          [](PyPassManager &passManager, const std::string &pipeline) {
            // Parser error text is accumulated and becomes the ValueError
            // message on failure.
            PyPrintAccumulator errorMsg;
            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
                mlirPassManagerGetAsOpPassManager(passManager.get()),
                mlirStringRefCreate(pipeline.data(), pipeline.size()),
                errorMsg.getCallback(), errorMsg.getUserData());
            if (mlirLogicalResultIsFailure(status))
              throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
          },
          py::arg("pipeline"),
          "Add textual pipeline elements to the pass manager. Throws a "
          "ValueError if the pipeline can't be parsed.")
      .def(
          "run",
          [](PyPassManager &passManager, PyOperationBase &op) {
            // Capture error diagnostics on the operation's context for the
            // duration of the run so they can be attached to the exception.
            PyMlirContext::ErrorCapture errors(op.getOperation().getContext());
            MlirLogicalResult status = mlirPassManagerRunOnOp(
                passManager.get(), op.getOperation().get());
            if (mlirLogicalResultIsFailure(status))
              throw MLIRError("Failure while executing pass pipeline",
                              errors.take());
          },
          py::arg("operation"),
          "Run the pass manager on the provided operation, raising an "
          "MLIRError on failure.")
      .def(
          "__str__",
          [](PyPassManager &self) {
            MlirPassManager passManager = self.get();
            PyPrintAccumulator printAccum;
            mlirPrintPassPipeline(
                mlirPassManagerGetAsOpPassManager(passManager),
                printAccum.getCallback(), printAccum.getUserData());
            return printAccum.join();
          },
          "Print the textual representation for this PassManager, suitable to "
          "be passed to `parse` for round-tripping.");
}