max 92233062c1 [mlir][python bindings] generate all the enums
This PR implements Python enum bindings for *all* the enums. This includes `I*Attrs` (including positional/bit) and `Dialect/EnumAttr`.

There are a few parts to this:

1. CMake: a small addition to `declare_mlir_dialect_python_bindings` and `declare_mlir_dialect_extension_python_bindings` to generate the enum, a boolean arg `GEN_ENUM_BINDINGS` to make it opt-in (even though it works for basically all of the dialects), and an optional `GEN_ENUM_BINDINGS_TD_FILE` for handling corner cases.
2. EnumPythonBindingGen.cpp: there are two weedy aspects here that took investigation:
    1. If an enum attribute is not a `Dialect/EnumAttr`, then the `EnumAttrInfo` record is canonical for both the cases of the enum **and the `AttrDefName`**. On the other hand, if an enum is a `Dialect/EnumAttr`, then the `EnumAttr` record has the correct `AttrDefName` ("load bearing", i.e., it populates `ods.ir.AttributeBuilder('<NAME>')`), but the cases live in its `enum` field, which is an instance of `EnumAttrInfo`. The solution is to generate one enum class for both `Dialect/EnumAttr` and "independent" `EnumAttrInfo`, but to make that class interoperable with two builder registrations that both do the right thing (see next sub-bullet).
    2. Because we don't have a good connection to the cpp `EnumAttr`, i.e., only the `enum class` getters are exposed (like `DimensionAttr::get(Dimension value)`), we have to resort to parsing, e.g., `Attribute.parse(f'#gpu<dim {x}>')`. This means that the set of supported `assemblyFormat`s (for the enum) is fixed at compile time of MLIR (currently 2, the only 2 I saw). There might be some things that could be done here, but they would require quite a bit more C API work to support generically (e.g., casting ints to enum cases and binding all the getters, or going generically through the `symbolize*` methods, like `symbolizeDimension(uint32_t)` or `symbolizeDimension(StringRef)`).
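
To make the parse-based approach concrete, here is a minimal sketch of how the text handed to `Attribute.parse` is assembled for one enum case. The helper name `renderEnumAttrAssembly` is hypothetical and is not part of MLIR. The real generator emits Python rather than C++. Only the `#gpu<dim x>` shape quoted above is covered; the second supported format is not reproduced here.

```cpp
#include <string>

// Hypothetical helper (an assumption for illustration, not an MLIR API):
// builds the textual attribute form "#<dialect><<mnemonic> <case>>" that a
// generated builder would pass to Attribute.parse for a single enum case.
std::string renderEnumAttrAssembly(const std::string &dialect,
                                   const std::string &mnemonic,
                                   const std::string &caseName) {
  return "#" + dialect + "<" + mnemonic + " " + caseName + ">";
}
```

Calling `renderEnumAttrAssembly("gpu", "dim", "x")` yields the string from the example above. Because the shape is hard-coded, any enum whose `assemblyFormat` prints differently would need its own template, which is why the supported set is fixed when MLIR is built.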

A few small changes:

1. In addition, since this patch registers default builders for attributes where people might've had their own builders already written, I added a `replace` param to `AttributeBuilder.insert` (`False` by default).
2. `makePythonEnumCaseName` can't handle all the different ways in which people write their enum cases, e.g., `llvm.CConv.Intel_OCL_BI`, which gets turned into `INTEL_O_C_L_B_I` (because `llvm::convertToSnakeFromCamelCase` doesn't look for runs of caps). So I dropped it. On the other hand, regularization does need to be done because some enums have `None` as a case (and others might have other Python keywords).
3. I turned on `llvm` dialect generation here in order to test `nvvm.WGMMAScaleIn`, which is an enum with [[ d7e26b5620/mlir/include/mlir/IR/EnumAttr.td (L22-L25) | no explicit discriminator ]] for the `neg` case.
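
The `INTEL_O_C_L_B_I` result in item 2 can be reproduced with a simplified stand-in for the conversion. This is a sketch under the assumption that the conversion puts an underscore before every capital letter without checking for runs of capitals. The names `naiveSnakeFromCamel` and `toUpperAscii` are made up for illustration and this is not the LLVM implementation.

```cpp
#include <cctype>
#include <string>

// Simplified stand-in (assumption): every uppercase letter starts a new
// word, so a run such as "OCL" is split into "o_c_l".
std::string naiveSnakeFromCamel(const std::string &input) {
  std::string out;
  out.reserve(input.size() * 2);
  for (char c : input) {
    unsigned char uc = static_cast<unsigned char>(c);
    if (!std::isupper(uc)) {
      out.push_back(c);
      continue;
    }
    // Avoid doubling an underscore that is already present in the input.
    if (!out.empty() && out.back() != '_')
      out.push_back('_');
    out.push_back(static_cast<char>(std::tolower(uc)));
  }
  return out;
}

// Upper-cases ASCII letters, mirroring the all-caps Python case names.
std::string toUpperAscii(std::string s) {
  for (char &c : s)
    c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
  return s;
}
```

Feeding `Intel_OCL_BI` through `naiveSnakeFromCamel` and then `toUpperAscii` gives the mangled name from item 2, since each letter of `OCL` and `BI` is treated as its own word.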

Note, dialects that didn't get a `GEN_ENUM_BINDINGS` don't have any enums to generate.

Let me know if I should add more tests (the three trivial ones I added exercise both the supported `assemblyFormat`s and `replace=True`).
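
For reference, the `replace` semantics those tests rely on can be sketched without pybind11. This is a minimal sketch assuming a plain `std::function` stands in for the Python callable. `BuilderRegistry`, `Builder`, and `lookup` are hypothetical names; the actual logic lives in `PyGlobals::registerAttributeBuilder`.

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Stand-in for a Python builder callable (assumption for illustration).
using Builder = std::function<std::string(const std::string &)>;

class BuilderRegistry {
public:
  // Registers `fn` under `kind`. An existing registration is only
  // overwritten when `replace` is true; otherwise an error is raised.
  void insert(const std::string &kind, Builder fn, bool replace = false) {
    Builder &found = builders[kind];
    if (found && !replace)
      throw std::runtime_error("Attribute builder for '" + kind +
                               "' is already registered");
    found = std::move(fn);
  }

  // Returns the registered builder, or nullptr if none exists.
  const Builder *lookup(const std::string &kind) const {
    auto it = builders.find(kind);
    if (it == builders.end() || !it->second)
      return nullptr;
    return &it->second;
  }

private:
  std::map<std::string, Builder> builders;
};
```

With `replace` left at its default, a second `insert` for the same kind throws and the first builder stays in place, so hand-written builders are not silently clobbered. Passing `replace = true` swaps in the new builder.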

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D157934
2023-08-23 15:03:55 -05:00


//===- IRModule.cpp - IR pybind module ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "IRModule.h"
#include "Globals.h"
#include "PybindUtils.h"
#include <optional>
#include <vector>
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------
PyGlobals *PyGlobals::instance = nullptr;
PyGlobals::PyGlobals() {
  assert(!instance && "PyGlobals already constructed");
  instance = this;
  // The default search path includes {mlir.}dialects, where {mlir.} is the
  // package prefix configured at compile time.
  dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
}
PyGlobals::~PyGlobals() { instance = nullptr; }
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
  if (loadedDialectModulesCache.contains(dialectNamespace))
    return;
  // Since re-entrancy is possible, make a copy of the search prefixes.
  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
  py::object loaded;
  for (std::string moduleName : localSearchPrefixes) {
    moduleName.push_back('.');
    moduleName.append(dialectNamespace.data(), dialectNamespace.size());
    try {
      loaded = py::module::import(moduleName.c_str());
    } catch (py::error_already_set &e) {
      if (e.matches(PyExc_ModuleNotFoundError)) {
        continue;
      }
      throw;
    }
    break;
  }
  // Note: Iterator cannot be shared from prior to loading, since re-entrancy
  // may have occurred, which may do anything.
  loadedDialectModulesCache.insert(dialectNamespace);
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
                                         py::function pyFunc, bool replace) {
  py::object &found = attributeBuilderMap[attributeKind];
  if (found && !found.is_none() && !replace) {
    throw std::runtime_error((llvm::Twine("Attribute builder for '") +
                              attributeKind +
                              "' is already registered with func: " +
                              py::str(found).operator std::string())
                                 .str());
  }
  found = std::move(pyFunc);
}
void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
                                   pybind11::function typeCaster,
                                   bool replace) {
  pybind11::object &found = typeCasterMap[mlirTypeID];
  if (found && !found.is_none() && !replace)
    throw std::runtime_error("Type caster is already registered");
  found = std::move(typeCaster);
}
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                    py::object pyClass) {
  py::object &found = dialectClassMap[dialectNamespace];
  if (found) {
    throw std::runtime_error((llvm::Twine("Dialect namespace '") +
                              dialectNamespace + "' is already registered.")
                                 .str());
  }
  found = std::move(pyClass);
}
void PyGlobals::registerOperationImpl(const std::string &operationName,
                                      py::object pyClass) {
  py::object &found = operationClassMap[operationName];
  if (found) {
    throw std::runtime_error((llvm::Twine("Operation '") + operationName +
                              "' is already registered.")
                                 .str());
  }
  found = std::move(pyClass);
}
std::optional<py::function>
PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) {
  // Fast match against the class map first (common case).
  const auto foundIt = attributeBuilderMap.find(attributeKind);
  if (foundIt != attributeBuilderMap.end()) {
    if (foundIt->second.is_none())
      return std::nullopt;
    assert(foundIt->second && "py::function is defined");
    return foundIt->second;
  }
  // Not found and loading did not yield a registration. Negative cache.
  attributeBuilderMap[attributeKind] = py::none();
  return std::nullopt;
}
std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
                                                        MlirDialect dialect) {
  {
    // Fast match against the class map first (common case).
    const auto foundIt = typeCasterMapCache.find(mlirTypeID);
    if (foundIt != typeCasterMapCache.end()) {
      if (foundIt->second.is_none())
        return std::nullopt;
      assert(foundIt->second && "py::function is defined");
      return foundIt->second;
    }
  }
  // Not found. Load the dialect namespace.
  loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
  // Attempt to find from the canonical map and cache.
  {
    const auto foundIt = typeCasterMap.find(mlirTypeID);
    if (foundIt != typeCasterMap.end()) {
      if (foundIt->second.is_none())
        return std::nullopt;
      assert(foundIt->second && "py::object is defined");
      // Positive cache.
      typeCasterMapCache[mlirTypeID] = foundIt->second;
      return foundIt->second;
    }
    // Negative cache.
    typeCasterMap[mlirTypeID] = py::none();
    return std::nullopt;
  }
}
std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
  loadDialectModule(dialectNamespace);
  // Fast match against the class map first (common case).
  const auto foundIt = dialectClassMap.find(dialectNamespace);
  if (foundIt != dialectClassMap.end()) {
    if (foundIt->second.is_none())
      return std::nullopt;
    assert(foundIt->second && "py::object is defined");
    return foundIt->second;
  }
  // Not found and loading did not yield a registration. Negative cache.
  dialectClassMap[dialectNamespace] = py::none();
  return std::nullopt;
}
std::optional<pybind11::object>
PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
  {
    auto foundIt = operationClassMapCache.find(operationName);
    if (foundIt != operationClassMapCache.end()) {
      if (foundIt->second.is_none())
        return std::nullopt;
      assert(foundIt->second && "py::object is defined");
      return foundIt->second;
    }
  }
  // Not found. Load the dialect namespace.
  auto split = operationName.split('.');
  llvm::StringRef dialectNamespace = split.first;
  loadDialectModule(dialectNamespace);
  // Attempt to find from the canonical map and cache.
  {
    auto foundIt = operationClassMap.find(operationName);
    if (foundIt != operationClassMap.end()) {
      if (foundIt->second.is_none())
        return std::nullopt;
      assert(foundIt->second && "py::object is defined");
      // Positive cache.
      operationClassMapCache[operationName] = foundIt->second;
      return foundIt->second;
    }
    // Negative cache.
    operationClassMap[operationName] = py::none();
    return std::nullopt;
  }
}
void PyGlobals::clearImportCache() {
  loadedDialectModulesCache.clear();
  operationClassMapCache.clear();
  typeCasterMapCache.clear();
}