
This PR implements python enum bindings for *all* the enums - this includes `I*Attrs` (including positional/bit) and `Dialect/EnumAttr`.
There are a few parts to this:
1. CMake: a small addition to `declare_mlir_dialect_python_bindings` and `declare_mlir_dialect_extension_python_bindings` to generate the enum, a boolean arg `GEN_ENUM_BINDINGS` to make it opt-in (even though it works for basically all of the dialects), and an optional `GEN_ENUM_BINDINGS_TD_FILE` for handling corner cases.
2. EnumPythonBindingGen.cpp: there are two tricky aspects here that took some investigation:
1. If an enum attribute is not a `Dialect/EnumAttr`, then the `EnumAttrInfo` record is canonical with respect to both the cases of the enum **and the `AttrDefName`**. On the other hand, if an enum is a `Dialect/EnumAttr`, then the `EnumAttr` record has the correct `AttrDefName` ("load bearing", i.e., it populates `ods.ir.AttributeBuilder('<NAME>')`), but the cases live in its `enum` field, which is an instance of `EnumAttrInfo`. The solution is to generate a single enum class for both `Dialect/EnumAttr` and "independent" `EnumAttrInfo`, and to make that class interoperable with two builder registrations that both do the right thing (see next sub-bullet).
2. Because we don't have a good connection to the C++ `EnumAttr` (i.e., only the `enum class` getters are exposed, like `DimensionAttr::get(Dimension value)`), we have to resort to parsing, e.g., `Attribute.parse(f'#gpu<dim {x}>')`. This means that the set of supported `assemblyFormat`s (for the enum) is fixed when MLIR is compiled (currently 2, the only 2 I saw). There might be some things that could be done here, but they would require quite a bit more C API work to support generically (e.g., casting ints to enum cases and binding all the getters, or going generically through the `symbolize*` methods, like `symbolizeDimension(uint32_t)` or `symbolizeDimension(StringRef)`).
A few small changes:
1. In addition, since this patch registers default builders for attributes where people might've had their own builders already written, I added a `replace` param to `AttributeBuilder.insert` (`False` by default).
2. `makePythonEnumCaseName` can't handle all the different ways in which people write their enum cases, e.g., `llvm.CConv.Intel_OCL_BI`, which gets turned into `INTEL_O_C_L_B_I` (because `llvm::convertToSnakeFromCamelCase` doesn't look for runs of caps). So I dropped it. On the other hand, some regularization does need to be done because some enums have `None` as a case (and others might have other Python keywords).
3. I turned on `llvm` dialect generation here in order to test `nvvm.WGMMAScaleIn`, which is an enum with [no explicit discriminator](https://github.com/llvm/llvm-project/blob/d7e26b5620/mlir/include/mlir/IR/EnumAttr.td#L22-L25) for the `neg` case.
Note: dialects that didn't get `GEN_ENUM_BINDINGS` don't have any enums to generate.
Let me know if I should add more tests (the three trivial ones I added exercise both the supported `assemblyFormat`s and `replace=True`).
Reviewed By: stellaraccident
Differential Revision: https://reviews.llvm.org/D157934
214 lines
7.2 KiB
C++
//===- IRModule.cpp - IR pybind module ------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "IRModule.h"
|
|
#include "Globals.h"
|
|
#include "PybindUtils.h"
|
|
|
|
#include <optional>
|
|
#include <vector>
|
|
|
|
#include "mlir-c/Bindings/Python/Interop.h"
|
|
#include "mlir-c/Support.h"
|
|
|
|
namespace py = pybind11;
|
|
using namespace mlir;
|
|
using namespace mlir::python;
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// PyGlobals
|
|
// -----------------------------------------------------------------------------
|
|
|
|
// Process-wide singleton pointer. The PyGlobals constructor sets it and the
// destructor resets it.
PyGlobals *PyGlobals::instance{nullptr};
|
|
|
|
/// Constructs the process-wide PyGlobals singleton and seeds the dialect
/// module search path with the package-qualified `dialects` prefix.
PyGlobals::PyGlobals() {
  // At most one instance may be alive at a time.
  assert(!instance && "PyGlobals already constructed");
  instance = this;
  // {mlir.}dialects is always searched first; the {mlir.} package prefix is
  // fixed by the build configuration.
  dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
|
|
|
|
/// Tears down the singleton by clearing the global pointer, which allows a
/// later PyGlobals to be constructed without tripping the constructor assert.
PyGlobals::~PyGlobals() {
  instance = nullptr;
}
|
|
|
|
/// Imports the Python module for `dialectNamespace` by trying
/// `<prefix>.<dialectNamespace>` for each registered search prefix, in order.
/// The first prefix whose module imports successfully wins.
///
/// A prefix whose module does not exist (ModuleNotFoundError) is skipped. Any
/// other Python exception raised during import is propagated to the caller,
/// and in that case the namespace is NOT recorded as loaded.
///
/// Otherwise, the namespace is recorded in loadedDialectModulesCache whether
/// or not a module was found. Subsequent calls are therefore no-ops until
/// clearImportCache() runs.
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
  if (loadedDialectModulesCache.contains(dialectNamespace))
    return;

  // Importing runs arbitrary Python code that may re-enter PyGlobals and
  // mutate dialectSearchPrefixes, so iterate over a snapshot.
  const std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
  for (const std::string &prefix : localSearchPrefixes) {
    std::string moduleName = prefix;
    moduleName.push_back('.');
    moduleName.append(dialectNamespace.data(), dialectNamespace.size());

    try {
      // The returned module object is not needed; the import is performed
      // only for its side effects.
      py::module::import(moduleName.c_str());
    } catch (py::error_already_set &e) {
      // NOTE(review): this also swallows a ModuleNotFoundError raised by a
      // *transitive* import inside an existing dialect module. Consider
      // checking the missing module's name before continuing.
      if (e.matches(PyExc_ModuleNotFoundError))
        continue;
      throw;
    }
    break;
  }

  // Record the namespace only after importing. Re-entrant calls made during
  // the import may have touched the cache, so nothing looked up before the
  // import is reused here.
  loadedDialectModulesCache.insert(dialectNamespace);
}
|
|
|
|
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
|
py::function pyFunc, bool replace) {
|
|
py::object &found = attributeBuilderMap[attributeKind];
|
|
if (found && !found.is_none() && !replace) {
|
|
throw std::runtime_error((llvm::Twine("Attribute builder for '") +
|
|
attributeKind +
|
|
"' is already registered with func: " +
|
|
py::str(found).operator std::string())
|
|
.str());
|
|
}
|
|
found = std::move(pyFunc);
|
|
}
|
|
|
|
/// Registers `typeCaster` for types whose TypeID is `mlirTypeID`.
/// Throws std::runtime_error if a caster is already registered and `replace`
/// is false. With `replace` true, the existing caster is overwritten.
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
                                   pybind11::function typeCaster,
                                   bool replace) {
  pybind11::object &found = typeCasterMap[mlirTypeID];
  // A None entry is a miss marker written by lookupTypeCaster(), not a real
  // registration.
  if (found && !found.is_none() && !replace)
    throw std::runtime_error("Type caster is already registered");
  found = std::move(typeCaster);
  // lookupTypeCaster() memoizes positive hits in typeCasterMapCache. Drop any
  // stale entry so that a replacement takes effect immediately, rather than
  // only after the next clearImportCache().
  typeCasterMapCache.erase(mlirTypeID);
}
|
|
|
|
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
|
|
py::object pyClass) {
|
|
py::object &found = dialectClassMap[dialectNamespace];
|
|
if (found) {
|
|
throw std::runtime_error((llvm::Twine("Dialect namespace '") +
|
|
dialectNamespace + "' is already registered.")
|
|
.str());
|
|
}
|
|
found = std::move(pyClass);
|
|
}
|
|
|
|
void PyGlobals::registerOperationImpl(const std::string &operationName,
|
|
py::object pyClass) {
|
|
py::object &found = operationClassMap[operationName];
|
|
if (found) {
|
|
throw std::runtime_error((llvm::Twine("Operation '") + operationName +
|
|
"' is already registered.")
|
|
.str());
|
|
}
|
|
found = std::move(pyClass);
|
|
}
|
|
|
|
std::optional<py::function>
|
|
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
|
|
// Fast match against the class map first (common case).
|
|
const auto foundIt = attributeBuilderMap.find(attributeKind);
|
|
if (foundIt != attributeBuilderMap.end()) {
|
|
if (foundIt->second.is_none())
|
|
return std::nullopt;
|
|
assert(foundIt->second && "py::function is defined");
|
|
return foundIt->second;
|
|
}
|
|
|
|
// Not found and loading did not yield a registration. Negative cache.
|
|
attributeBuilderMap[attributeKind] = py::none();
|
|
return std::nullopt;
|
|
}
|
|
|
|
/// Returns the type caster registered for `mlirTypeID`, or std::nullopt.
/// On a cache miss, this first imports the Python module of `dialect`, which
/// may register casters, and then consults the canonical typeCasterMap.
/// Positive hits are memoized in typeCasterMapCache. Misses are recorded as
/// None in typeCasterMap.
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
                                                        MlirDialect dialect) {
  // Fast path: a hit memoized by an earlier lookup.
  if (auto cached = typeCasterMapCache.find(mlirTypeID);
      cached != typeCasterMapCache.end()) {
    if (cached->second.is_none())
      return std::nullopt;
    assert(cached->second && "py::function is defined");
    return cached->second;
  }

  // Slow path: make sure the owning dialect's Python module has been loaded.
  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));

  auto registered = typeCasterMap.find(mlirTypeID);
  if (registered == typeCasterMap.end()) {
    // Record the miss with a None marker in the canonical map.
    typeCasterMap[mlirTypeID] = py::none();
    return std::nullopt;
  }
  if (registered->second.is_none())
    return std::nullopt;
  assert(registered->second && "py::object is defined");
  // Memoize the positive hit.
  typeCasterMapCache[mlirTypeID] = registered->second;
  return registered->second;
}
|
|
|
|
std::optional<py::object>
|
|
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
|
|
loadDialectModule(dialectNamespace);
|
|
// Fast match against the class map first (common case).
|
|
const auto foundIt = dialectClassMap.find(dialectNamespace);
|
|
if (foundIt != dialectClassMap.end()) {
|
|
if (foundIt->second.is_none())
|
|
return std::nullopt;
|
|
assert(foundIt->second && "py::object is defined");
|
|
return foundIt->second;
|
|
}
|
|
|
|
// Not found and loading did not yield a registration. Negative cache.
|
|
dialectClassMap[dialectNamespace] = py::none();
|
|
return std::nullopt;
|
|
}
|
|
|
|
std::optional<pybind11::object>
|
|
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
|
|
{
|
|
auto foundIt = operationClassMapCache.find(operationName);
|
|
if (foundIt != operationClassMapCache.end()) {
|
|
if (foundIt->second.is_none())
|
|
return std::nullopt;
|
|
assert(foundIt->second && "py::object is defined");
|
|
return foundIt->second;
|
|
}
|
|
}
|
|
|
|
// Not found. Load the dialect namespace.
|
|
auto split = operationName.split('.');
|
|
llvm::StringRef dialectNamespace = split.first;
|
|
loadDialectModule(dialectNamespace);
|
|
|
|
// Attempt to find from the canonical map and cache.
|
|
{
|
|
auto foundIt = operationClassMap.find(operationName);
|
|
if (foundIt != operationClassMap.end()) {
|
|
if (foundIt->second.is_none())
|
|
return std::nullopt;
|
|
assert(foundIt->second && "py::object is defined");
|
|
// Positive cache.
|
|
operationClassMapCache[operationName] = foundIt->second;
|
|
return foundIt->second;
|
|
}
|
|
// Negative cache.
|
|
operationClassMap[operationName] = py::none();
|
|
return std::nullopt;
|
|
}
|
|
}
|
|
|
|
/// Drops the memoization caches. The next lookups will re-attempt dialect
/// module imports and re-consult the canonical registration maps. The
/// canonical maps themselves, including any None miss markers, are left
/// untouched.
void PyGlobals::clearImportCache() {
  typeCasterMapCache.clear();
  operationClassMapCache.clear();
  loadedDialectModulesCache.clear();
}
|