Mehdi Amini dc43f78565 Add basic Python bindings for the PassManager and bind libTransforms
This only exposes the ability to round-trip a textual pipeline at the
moment.
To exercise it, we also bind libTransforms in a new Python extension. This
does not include any interesting bindings, but it includes all the
mechanisms needed to add separate native extensions and load them dynamically.
As such, passes in libTransforms are only registered after `import
mlir.transforms`.
To support this global registration, the TableGen backend is also
extended to bind to the C API the group registration for passes.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D90819
2020-11-10 19:55:21 +00:00

220 lines
7.7 KiB
C++

//===- MainModule.cpp - Main pybind module --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <tuple>
#include "PybindUtils.h"
#include "Globals.h"
#include "IRModules.h"
#include "Pass.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
/// Storage for the singleton pointer. It is assigned by the PyGlobals
/// constructor and reset to null by the destructor.
PyGlobals *PyGlobals::instance{nullptr};
/// Installs the object under construction as the singleton instance. In
/// assert-enabled builds, constructing a second PyGlobals while `instance` is
/// still set aborts with the message below.
PyGlobals::PyGlobals() {
  assert(instance == nullptr && "PyGlobals already constructed");
  instance = this;
}
/// Resets the singleton pointer so it no longer refers to the destroyed
/// object.
PyGlobals::~PyGlobals() {
  instance = nullptr;
}
/// Imports the Python module associated with `dialectNamespace`.
///
/// Every entry of `dialectSearchPrefixes` is tried in order as
/// `<prefix>.<dialectNamespace>`; the loop stops at the first import that
/// succeeds. A ModuleNotFoundError moves on to the next prefix, while any
/// other Python error is re-thrown to the caller. On normal completion the
/// namespace is recorded in `loadedDialectModulesCache` regardless of whether
/// a module was found, so subsequent calls return early until that cache is
/// cleared.
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
  // The guard has to be a named object. An unnamed temporary is destroyed at
  // the end of its own statement, which would drop the GIL again right away.
  py::gil_scoped_acquire acquire;
  if (loadedDialectModulesCache.contains(dialectNamespace))
    return;
  // Since re-entrancy is possible, make a copy of the search prefixes.
  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
  // `moduleName` is deliberately taken by value: it is extended in place.
  for (std::string moduleName : localSearchPrefixes) {
    moduleName.push_back('.');
    moduleName.append(dialectNamespace.data(), dialectNamespace.size());
    try {
      // Importing runs Python code, so the GIL must stay held here. The
      // returned module object is not needed by this function.
      py::module::import(moduleName.c_str());
    } catch (py::error_already_set &e) {
      if (e.matches(PyExc_ModuleNotFoundError))
        continue;
      throw;
    }
    break;
  }
  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
  // may have occurred, which may do anything.
  loadedDialectModulesCache.insert(dialectNamespace);
}
/// Records `pyClass` as the Python class for `dialectNamespace` in
/// `dialectClassMap`.
///
/// Throws the error produced by SetPyError (a Python RuntimeError) when the
/// namespace already maps to a real class. An entry holding Python `None` is
/// treated as unregistered: lookupDialectClass stores `None` as a
/// negative-cache placeholder, and such a placeholder must not block a later
/// genuine registration.
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                    py::object pyClass) {
  // Named guard: an unnamed temporary would release the GIL immediately.
  py::gil_scoped_acquire acquire;
  py::object &found = dialectClassMap[dialectNamespace];
  // `found` converts to true for any non-null handle, including `None`, so
  // the placeholder case has to be excluded explicitly.
  if (found && !found.is_none()) {
    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
                                             dialectNamespace +
                                             "' is already registered.");
  }
  found = std::move(pyClass);
}
/// Records `pyClass` as the Python class for `operationName` in
/// `operationClassMap` and stores `rawOpViewClass` under the same name in
/// `rawOpViewClassMap`.
///
/// Throws the error produced by SetPyError (a Python RuntimeError) when
/// `operationClassMap` already has a non-null entry for the operation.
void PyGlobals::registerOperationImpl(const std::string &operationName,
                                      py::object pyClass,
                                      py::object rawOpViewClass) {
  // Named guard: an unnamed temporary would release the GIL immediately.
  py::gil_scoped_acquire acquire;
  py::object &found = operationClassMap[operationName];
  if (found) {
    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
                                             operationName +
                                             "' is already registered.");
  }
  found = std::move(pyClass);
  // Unconditionally overwrites any previous entry, including a `None`
  // placeholder left behind by lookupRawOpViewClass.
  rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}
/// Returns the Python class registered for `dialectNamespace`, if any.
///
/// The dialect module is loaded first (which may trigger registration), then
/// `dialectClassMap` is consulted. When nothing is registered, a Python `None`
/// placeholder is stored for the namespace and llvm::None is returned; later
/// lookups that hit the placeholder also return llvm::None.
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
  // Named guard: an unnamed temporary would release the GIL immediately.
  py::gil_scoped_acquire acquire;
  loadDialectModule(dialectNamespace);
  // Fast match against the class map first (common case).
  const auto foundIt = dialectClassMap.find(dialectNamespace);
  if (foundIt != dialectClassMap.end()) {
    // A `None` entry is the negative-cache placeholder written below.
    if (foundIt->second.is_none())
      return llvm::None;
    assert(foundIt->second && "py::object is defined");
    return foundIt->second;
  }
  // Not found and loading did not yield a registration. Negative cache.
  dialectClassMap[dialectNamespace] = py::none();
  return llvm::None;
}
/// Returns the "raw" OpView subclass registered for `operationName`, if any.
///
/// Lookup order:
///   1. `rawOpViewClassMapCache` (fast path).
///   2. Load the dialect module named by the text before the first '.' in
///      `operationName`, then consult the canonical `rawOpViewClassMap`.
/// A hit in step 2 is copied into the cache. A miss stores a Python `None`
/// placeholder and yields llvm::None.
llvm::Optional<pybind11::object>
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
  {
    // Named guard, scoped to the cache probe. An unnamed temporary would
    // release the GIL immediately.
    py::gil_scoped_acquire acquire;
    auto foundIt = rawOpViewClassMapCache.find(operationName);
    if (foundIt != rawOpViewClassMapCache.end()) {
      if (foundIt->second.is_none())
        return llvm::None;
      assert(foundIt->second && "py::object is defined");
      return foundIt->second;
    }
  }
  // Not found. Load the dialect namespace.
  auto split = operationName.split('.');
  llvm::StringRef dialectNamespace = split.first;
  loadDialectModule(dialectNamespace);
  // Attempt to find from the canonical map and cache.
  {
    py::gil_scoped_acquire acquire;
    auto foundIt = rawOpViewClassMap.find(operationName);
    if (foundIt != rawOpViewClassMap.end()) {
      if (foundIt->second.is_none())
        return llvm::None;
      assert(foundIt->second && "py::object is defined");
      // Positive cache.
      rawOpViewClassMapCache[operationName] = foundIt->second;
      return foundIt->second;
    }
    // Negative cache.
    // NOTE(review): the placeholder goes into the canonical map rather than
    // rawOpViewClassMapCache, so clearImportCache() does not drop it. Confirm
    // this is intended.
    rawOpViewClassMap[operationName] = py::none();
    return llvm::None;
  }
}
/// Empties `loadedDialectModulesCache` and `rawOpViewClassMapCache` so that
/// subsequent lookups attempt to load dialect modules again.
void PyGlobals::clearImportCache() {
  // Named guard: an unnamed temporary would release the GIL immediately, and
  // clearing the raw-class cache drops references to Python objects.
  py::gil_scoped_acquire acquire;
  loadedDialectModulesCache.clear();
  rawOpViewClassMapCache.clear();
}
// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------
/// Entry point of the `_mlir` native extension: binds the globals object, the
/// registration decorators, and the `ir` / `passmanager` submodules.
PYBIND11_MODULE(_mlir, m) {
  m.doc() = "MLIR Python Native Extension";

  // Bind the PyGlobals type under the private Python name `_Globals`.
  py::class_<PyGlobals> globalsClass(m, "_Globals");
  globalsClass.def_property("dialect_search_modules",
                            &PyGlobals::getDialectSearchPrefixes,
                            &PyGlobals::setDialectSearchPrefixes);
  globalsClass.def("append_dialect_search_prefix",
                   [](PyGlobals &globals, std::string prefix) {
                     auto &prefixes = globals.getDialectSearchPrefixes();
                     prefixes.push_back(std::move(prefix));
                     // A new prefix can change what an import resolves to.
                     globals.clearImportCache();
                   });
  globalsClass.def("_register_dialect_impl", &PyGlobals::registerDialectImpl,
                   "Testing hook for directly registering a dialect");
  globalsClass.def("_register_operation_impl",
                   &PyGlobals::registerOperationImpl,
                   "Testing hook for directly registering an operation");

  // Ownership of the globals object is handed to Python. Besides exposing it
  // as `globals`, this ensures it is destroyed properly and releases the
  // Python resources it holds.
  m.attr("globals") =
      py::cast(new PyGlobals, py::return_value_policy::take_ownership);

  // Decorator: @register_dialect
  m.def(
      "register_dialect",
      [](py::object dialectCls) {
        auto nsName = dialectCls.attr("DIALECT_NAMESPACE").cast<std::string>();
        PyGlobals::get().registerDialectImpl(nsName, dialectCls);
        // Hand the class back so the decorator is transparent.
        return dialectCls;
      },
      "Class decorator for registering a custom Dialect wrapper");

  // Decorator factory: @register_operation(DialectClass)
  m.def(
      "register_operation",
      [](py::object dialectClass) -> py::cpp_function {
        // The returned callable is the actual class decorator; it keeps the
        // dialect class alive through its capture.
        auto decorator = [dialectClass](py::object opCls) -> py::object {
          auto opName = opCls.attr("OPERATION_NAME").cast<std::string>();
          auto rawCls = PyOpView::createRawSubclass(opCls);
          PyGlobals::get().registerOperationImpl(opName, opCls, rawCls);
          // Attach the op class to the dialect class under its own name.
          py::object clsName = opCls.attr("__name__");
          dialectClass.attr(clsName) = opCls;
          // Expose the special "Raw" subclass, which passes construction
          // through to the OpView parent and bypasses the intermediate
          // child's __init__.
          opCls.attr("_Raw") = rawCls;
          return opCls;
        };
        return py::cpp_function(std::move(decorator));
      },
      "Class decorator for registering a custom Operation wrapper");

  // IR submodule.
  auto irSubmodule = m.def_submodule("ir", "MLIR IR Bindings");
  populateIRSubmodule(irSubmodule);

  // PassManager submodule.
  auto passManagerSubmodule =
      m.def_submodule("passmanager", "MLIR Pass Management Bindings");
  populatePassManagerSubmodule(passManagerSubmodule);
}