When multiple dialects share td `#include`s (e.g. `affine` includes
`arith`), each dialect's `*_enum_gen.py` file registers attribute
builders under the same keys, causing "already registered" errors on the
second import. The first commit adds a test for exactly this case; it
currently fails on main:
```
# | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: <function _arith_cmpfpredicateattr at 0x78d13cbe9a80>
```
This PR implements a two-pronged fix:
1. Add `allow_existing=True` to `register_attribute_builder` (and the
underlying C++ `registerAttributeBuilder`). When set, registration is
silently skipped if the key already exists (first registration wins). This
handles `EnumInfo`-based builders, which have no dialect prefix (e.g.
`AtomicRMWKindAttr`, `Arith_CmpFPredicateAttr`) and may be emitted by
every dialect whose td file includes the defining file;
2. Filter `EnumAttr` builders by `-bind-dialect` in
`EnumPythonBindingGen.cpp` and register them under dialect-qualified
keys (`"dialect.AttrName"`). Update `OpPythonBindingGen.cpp` to look up
the same qualified keys for `EnumAttr`-typed op attributes (detected via
`isSubClassOf("EnumAttr")`). Pass `-bind-dialect` from
`AddMLIRPython.cmake`.
This approach requires no changes to the `ir.py` registrations (no "builtin."
prefix) and no manual builder additions to individual dialect Python
files (unlike the previous attempt,
https://github.com/llvm/llvm-project/pull/117918).
Note: this PR was "clauded", not "coded".
182 lines
7.1 KiB
C++
182 lines
7.1 KiB
C++
//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
|
|
// generate the corresponding Python binding classes.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
#include "OpGenHelpers.h"
|
|
|
|
#include "mlir/TableGen/AttrOrTypeDef.h"
|
|
#include "mlir/TableGen/Attribute.h"
|
|
#include "mlir/TableGen/Dialect.h"
|
|
#include "mlir/TableGen/EnumInfo.h"
|
|
#include "mlir/TableGen/GenInfo.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tblgen;
|
|
using llvm::formatv;
|
|
using llvm::Record;
|
|
using llvm::RecordKeeper;
|
|
|
|
// Declared in OpPythonBindingGen.cpp; the two generators share the same
|
|
// -bind-dialect option to allow filtering enum registrations by dialect.
|
|
extern std::string dialectNameStorage;
|
|
|
|
/// Preamble written at the top of every generated Python enum module.
///
/// It imports the Python `enum` base classes, the `_cext` extension handle,
/// and the `register_attribute_builder` decorator that the emitted attribute
/// builders are decorated with, then binds `_ods_ir` for use by those builders.
constexpr const char *fileHeader =
    "\n"
    "# Autogenerated by mlir-tblgen; don't manually edit.\n"
    "\n"
    "from enum import IntEnum, auto, IntFlag\n"
    "from ._ods_common import _cext as _ods_cext\n"
    "from ..ir import register_attribute_builder\n"
    "_ods_ir = _ods_cext.ir\n"
    "\n";
|
|
|
|
/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
|
|
static std::string makePythonEnumCaseName(StringRef name) {
|
|
if (isPythonReserved(name.str()))
|
|
return (name + "_").str();
|
|
return name.str();
|
|
}
|
|
|
|
/// Emits the Python class for the given enum.
|
|
static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) {
|
|
os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(),
|
|
enumInfo.isBitEnum() ? "IntFlag" : "IntEnum");
|
|
if (!enumInfo.getSummary().empty())
|
|
os << formatv(" \"\"\"{0}\"\"\"\n", enumInfo.getSummary());
|
|
os << "\n";
|
|
|
|
for (const EnumCase &enumCase : enumInfo.getAllCases()) {
|
|
os << formatv(" {0} = {1}\n",
|
|
makePythonEnumCaseName(enumCase.getSymbol()),
|
|
enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue())
|
|
: "auto()");
|
|
}
|
|
|
|
os << "\n";
|
|
|
|
if (enumInfo.isBitEnum()) {
|
|
os << formatv(" def __iter__(self):\n"
|
|
" return iter([case for case in type(self) if "
|
|
"(self & case) is case and self is not case])\n");
|
|
os << formatv(" def __len__(self):\n"
|
|
" return bin(self).count(\"1\")\n");
|
|
os << "\n";
|
|
}
|
|
|
|
os << formatv(" def __str__(self):\n");
|
|
if (enumInfo.isBitEnum())
|
|
os << formatv(" if len(self) > 1:\n"
|
|
" return \"{0}\".join(map(str, self))\n",
|
|
enumInfo.getDef().getValueAsString("separator"));
|
|
for (const EnumCase &enumCase : enumInfo.getAllCases()) {
|
|
os << formatv(" if self is {0}.{1}:\n", enumInfo.getEnumClassName(),
|
|
makePythonEnumCaseName(enumCase.getSymbol()));
|
|
os << formatv(" return \"{0}\"\n", enumCase.getStr());
|
|
}
|
|
os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n",
|
|
enumInfo.getEnumClassName());
|
|
os << "\n";
|
|
}
|
|
|
|
/// Emits an attribute builder for the given enum attribute to support automatic
|
|
/// conversion between enum values and attributes in Python. Returns
|
|
/// `false` on success, `true` on failure.
|
|
static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
|
|
std::optional<Attribute> enumAttrInfo = enumInfo.asEnumAttr();
|
|
if (!enumAttrInfo)
|
|
return false;
|
|
|
|
int64_t bitwidth = enumInfo.getBitwidth();
|
|
// These builders may be emitted by multiple dialect enum_gen files when
|
|
// dialects share enum definitions via .td includes. Use allow_existing=True
|
|
// so that the first loaded dialect registers the builder and subsequent
|
|
// loads silently skip (first-registration wins).
|
|
os << formatv("@register_attribute_builder(\"{0}\", allow_existing=True)\n",
|
|
enumAttrInfo->getAttrDefName());
|
|
os << formatv("def _{0}(x, context):\n",
|
|
enumAttrInfo->getAttrDefName().lower());
|
|
os << formatv(" return "
|
|
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
|
|
"context=context), int(x))\n\n",
|
|
bitwidth);
|
|
return false;
|
|
}
|
|
|
|
/// Emits an attribute builder for the given dialect enum attribute to support
|
|
/// automatic conversion between enum values and attributes in Python. Returns
|
|
/// `false` on success, `true` on failure.
|
|
static bool emitDialectEnumAttributeBuilder(StringRef dialect,
|
|
StringRef attrDefName,
|
|
StringRef formatString,
|
|
raw_ostream &os) {
|
|
os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect,
|
|
attrDefName);
|
|
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
|
|
os << formatv(" return "
|
|
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
|
|
formatString);
|
|
return false;
|
|
}
|
|
|
|
/// Emits Python bindings for all enums in the record keeper. Returns
|
|
/// `false` on success, `true` on failure.
|
|
static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
|
|
os << fileHeader;
|
|
for (const Record *it :
|
|
records.getAllDerivedDefinitionsIfDefined("EnumInfo")) {
|
|
EnumInfo enumInfo(*it);
|
|
emitEnumClass(enumInfo, os);
|
|
emitAttributeBuilder(enumInfo, os);
|
|
}
|
|
for (const Record *it :
|
|
records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
|
|
AttrOrTypeDef attr(&*it);
|
|
StringRef dialect = attr.getDialect().getName();
|
|
// When -bind-dialect is specified, only emit builders for EnumAttr records
|
|
// belonging to that dialect. This prevents duplicate registrations when
|
|
// multiple dialects include the same .td files.
|
|
if (!dialectNameStorage.empty() && dialect != dialectNameStorage)
|
|
continue;
|
|
if (!attr.getMnemonic()) {
|
|
llvm::errs() << "enum case " << attr
|
|
<< " needs mnemonic for python enum bindings generation";
|
|
return true;
|
|
}
|
|
StringRef mnemonic = attr.getMnemonic().value();
|
|
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
|
|
if (assemblyFormat == "`<` $value `>`") {
|
|
emitDialectEnumAttributeBuilder(
|
|
dialect, attr.getName(),
|
|
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
|
|
} else if (assemblyFormat == "$value") {
|
|
emitDialectEnumAttributeBuilder(
|
|
dialect, attr.getName(),
|
|
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
|
|
} else {
|
|
llvm::errs()
|
|
<< "unsupported assembly format for python enum bindings generation";
|
|
return true;
|
|
}
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
// Registers the enum utility generator to mlir-tblgen.
|
|
static mlir::GenRegistration
|
|
genPythonEnumBindings("gen-python-enum-bindings",
|
|
"Generate Python bindings for enum attributes",
|
|
&emitPythonEnums);
|