[mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191)

When multiple dialects share td `#includes` (e.g. `affine` includes
`arith`), each dialect's `*_enum_gen.py` file registers attribute
builders under the same keys, causing "already registered" errors on the
second import; the first commit checks in such a case which currently
fails on main:

```
# | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: <function _arith_cmpfpredicateattr at 0x78d13cbe9a80>
```

This PR implements a two-pronged fix:

1. Add `allow_existing=True` to `register_attribute_builder` (and the
underlying C++ `registerAttributeBuilder`). When set, silently skips
registration if the key already exists (first-wins semantics). This
handles `EnumInfo`-based builders which have no dialect prefix (e.g.
`AtomicRMWKindAttr`, `Arith_CmpFPredicateAttr`), which may be emitted by
every dialect whose td file includes the defining file;
2. Filter `EnumAttr` builders by `-bind-dialect` in
`EnumPythonBindingGen.cpp` and register them under dialect qualified
keys (`"dialect.AttrName"`). Update `OpPythonBindingGen.cpp` to look up
the same qualified keys for EnumAttr typed op attributes (detected via
`isSubClassOf("EnumAttr")`). Pass `-bind-dialect` from
`AddMLIRPython.cmake`.

This approach incurs no changes to `ir.py` registrations (no "builtin."
prefix), and no manual builder additions to individual dialect Python
files (unlike the previous attempt
https://github.com/llvm/llvm-project/pull/117918).

Note, this PR was "clauded" not "coded".
This commit is contained in:
Maksim Levental 2026-03-19 21:02:23 -07:00 committed by GitHub
parent fa2df7e853
commit c3f381ccfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 92 additions and 33 deletions

View File

@ -608,7 +608,7 @@ function(declare_mlir_dialect_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()
@ -680,7 +680,7 @@ function(declare_mlir_dialect_extension_python_bindings)
set(LLVM_TARGET_DEFINITIONS ${td_file})
endif()
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
list(APPEND _sources ${enum_filename})
endif()

View File

@ -58,11 +58,14 @@ public:
bool loadDialectModule(std::string_view dialectNamespace);
/// Adds a user-friendly Attribute builder.
/// Raises an exception if the mapping already exists and replace == false.
/// Raises an exception if the mapping already exists and replace == false
/// and allow_existing == false.
/// Silently skips registration if allow_existing == true and the mapping
/// already exists (first registration wins).
/// This is intended to be called by implementation code.
void registerAttributeBuilder(const std::string &attributeKind,
nanobind::callable pyFunc,
bool replace = false);
nanobind::callable pyFunc, bool replace = false,
bool allow_existing = false);
/// Adds a user-friendly type caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by

View File

@ -1356,7 +1356,8 @@ struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
static nanobind::callable
dunderGetItemNamed(const std::string &attributeKind);
static void dunderSetItemNamed(const std::string &attributeKind,
nanobind::callable func, bool replace);
nanobind::callable func, bool replace,
bool allow_existing);
static void bind(nanobind::module_ &m);
};

View File

@ -97,14 +97,27 @@ bool PyGlobals::loadDialectModule(std::string_view dialectNamespace) {
}
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
nb::callable pyFunc, bool replace) {
nb::callable pyFunc, bool replace,
bool allowExisting) {
nb::ft_lock_guard lock(mutex);
nb::object &found = attributeBuilderMap[attributeKind];
if (found && !replace) {
throw std::runtime_error(
if (found) {
std::string msg =
nanobind::detail::join("Attribute builder for '", attributeKind,
"' is already registered with func: ",
nb::cast<std::string>(nb::str(found))));
nb::cast<std::string>(nb::str(found)));
if (allowExisting) {
#ifndef NDEBUG
if (PyErr_WarnEx(PyExc_RuntimeWarning, msg.c_str(), 1) < 0) {
// If the user has set warnings to errors (e.g., via -Werror),
// PyErr_WarnEx returns -1 and sets a Python exception.
throw nb::python_error();
}
#endif
return;
}
if (!replace)
throw std::runtime_error(msg);
}
found = std::move(pyFunc);
}

View File

@ -154,9 +154,10 @@ PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
}
void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
nb::callable func, bool replace) {
nb::callable func, bool replace,
bool allow_existing) {
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
replace);
replace, allow_existing);
}
void PyAttrBuilderMap::bind(nb::module_ &m) {
@ -171,6 +172,7 @@ void PyAttrBuilderMap::bind(nb::module_ &m) {
"attribute kind.")
.def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
"attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
"allow_existing"_a = false,
"Register an attribute builder for building MLIR "
"attributes from Python values.");
}

View File

@ -95,9 +95,9 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, Non
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind, replace=False):
def register_attribute_builder(kind, replace=False, allow_existing=False):
def decorator_builder(func):
AttrBuilder.insert(kind, func, replace=replace)
AttrBuilder.insert(kind, func, replace=replace, allow_existing=allow_existing)
return func
return decorator_builder

View File

@ -35,7 +35,7 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>
// CHECK: return "negone"
// CHECK: raise ValueError("Unknown MyEnum enum entry.")
// CHECK: @register_attribute_builder("MyEnum")
// CHECK: @register_attribute_builder("MyEnum", allow_existing=True)
// CHECK: def _myenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
@ -58,7 +58,7 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
// CHECK: return "two"
// CHECK: raise ValueError("Unknown MyEnum64 enum entry.")
// CHECK: @register_attribute_builder("MyEnum64")
// CHECK: @register_attribute_builder("MyEnum64", allow_existing=True)
// CHECK: def _myenum64(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
@ -102,14 +102,14 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// CHECK: return "any"
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
// CHECK: @register_attribute_builder("TestBitEnum")
// CHECK: @register_attribute_builder("TestBitEnum", allow_existing=True)
// CHECK: def _testbitenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
// CHECK: @register_attribute_builder("TestBitEnum_Attr")
// CHECK: @register_attribute_builder("TestDialect.TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
// CHECK: @register_attribute_builder("TestMyEnum_Attr")
// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr")
// CHECK: def _testmyenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)

View File

@ -335,6 +335,11 @@ def testAffineIfWithElse():
return
@constructAndPrintInModule
def test_double_AtomicRMWKindAttr_registration():
from mlir.dialects import _affine_enum_gen
# CHECK-LABEL: TEST: testAffineIfOpInsertionPoint
@constructAndPrintInModule
def testAffineIfOpInsertionPoint():

View File

@ -94,7 +94,7 @@ def testCeilDivUOp(ctx):
def testCmpOp(ctx):
a = index.ConstantOp(value=42)
b = index.ConstantOp(value=23)
pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
pred = AttrBuilder.get("index.IndexCmpPredicateAttr")("slt", context=ctx)
r = index.CmpOp(pred, lhs=a, rhs=b)
# CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})

View File

@ -17,6 +17,7 @@
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/EnumInfo.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Record.h"
@ -26,6 +27,10 @@ using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
// Declared in OpPythonBindingGen.cpp; the two generators share the same
// -bind-dialect option to allow filtering enum registrations by dialect.
extern std::string dialectNameStorage;
/// File header and includes.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
@ -94,7 +99,11 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
return false;
int64_t bitwidth = enumInfo.getBitwidth();
os << formatv("@register_attribute_builder(\"{0}\")\n",
// These builders may be emitted by multiple dialect enum_gen files when
// dialects share enum definitions via .td includes. Use allow_existing=True
// so that the first loaded dialect registers the builder and subsequent
// loads silently skip (first-registration wins).
os << formatv("@register_attribute_builder(\"{0}\", allow_existing=True)\n",
enumAttrInfo->getAttrDefName());
os << formatv("def _{0}(x, context):\n",
enumAttrInfo->getAttrDefName().lower());
@ -108,10 +117,12 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
/// Emits an attribute builder for the given dialect enum attribute to support
/// automatic conversion between enum values and attributes in Python. Returns
/// `false` on success, `true` on failure.
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
static bool emitDialectEnumAttributeBuilder(StringRef dialect,
StringRef attrDefName,
StringRef formatString,
raw_ostream &os) {
os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect,
attrDefName);
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
os << formatv(" return "
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
@ -132,6 +143,12 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
for (const Record *it :
records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
AttrOrTypeDef attr(&*it);
StringRef dialect = attr.getDialect().getName();
// When -bind-dialect is specified, only emit builders for EnumAttr records
// belonging to that dialect. This prevents duplicate registrations when
// multiple dialects include the same .td files.
if (!dialectNameStorage.empty() && dialect != dialectNameStorage)
continue;
if (!attr.getMnemonic()) {
llvm::errs() << "enum case " << attr
<< " needs mnemonic for python enum bindings generation";
@ -139,14 +156,13 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
}
StringRef mnemonic = attr.getMnemonic().value();
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
StringRef dialect = attr.getDialect().getName();
if (assemblyFormat == "`<` $value `>`") {
emitDialectEnumAttributeBuilder(
attr.getName(),
dialect, attr.getName(),
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
} else if (assemblyFormat == "$value") {
emitDialectEnumAttributeBuilder(
attr.getName(),
dialect, attr.getName(),
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
} else {
llvm::errs()

View File

@ -341,10 +341,13 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
static llvm::cl::opt<std::string>
std::string dialectNameStorage;
llvm::cl::opt<std::string, /*ExternalStorage=*/true>
clDialectName("bind-dialect",
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
llvm::cl::location(dialectNameStorage),
llvm::cl::cat(clOpPythonBindingCat));
static llvm::cl::opt<std::string> clDialectExtensionName(
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
@ -887,11 +890,27 @@ populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
continue;
}
// For EnumAttr-style attributes (those defined as EnumAttr<Dialect, ...>
// in tablegen), use a dialect-qualified key ("dialect.AttrName") so the
// lookup matches the registration emitted by EnumPythonBindingGen with
// -bind-dialect. For all other attributes (plain attrs like I32Attr,
// custom AttrDef, etc.), keep the unqualified name to match their
// registrations in ir.py or dialect-specific Python files.
Attribute baseAttr = attribute->attr.getBaseAttr();
Dialect attrDialect = baseAttr.isSubClassOf("EnumAttr")
? baseAttr.getDialect()
: Dialect(nullptr);
std::string attrBuilderKey = attrDialect
? formatv("{0}.{1}", attrDialect.getName(),
attribute->attr.getAttrDefName())
.str()
: attribute->attr.getAttrDefName().str();
builderLines.push_back(formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
argNames[i], attribute->name, attrBuilderKey));
}
}
@ -1307,18 +1326,18 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
if (dialectNameStorage.empty())
llvm::PrintFatalError("dialect name not provided");
os << fileHeader;
if (!clDialectExtensionName.empty())
os << formatv(dialectExtensionTemplate, clDialectName.getValue());
os << formatv(dialectExtensionTemplate, dialectNameStorage);
else
os << formatv(dialectClassTemplate, clDialectName.getValue());
os << formatv(dialectClassTemplate, dialectNameStorage);
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
if (op.getDialectName() == clDialectName.getValue())
if (op.getDialectName() == dialectNameStorage)
emitOpBindings(op, os);
}
return false;