From c3f381ccfe4b48f204df07e2c8cd36542a60d553 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 19 Mar 2026 21:02:23 -0700 Subject: [PATCH] [mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191) When multiple dialects share `.td` `include`s (e.g. `affine` includes `arith`), each dialect's `*_enum_gen.py` file registers attribute builders under the same keys, causing "already registered" errors on the second import. The first commit of this PR checks in a test for such a case, which currently fails on main: ``` # | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: ``` This PR implements a two-pronged fix: 1. Add an `allow_existing` parameter (default `False`) to `register_attribute_builder` (and the underlying C++ `registerAttributeBuilder`). When it is `True`, registration is skipped if the key already exists (first-wins semantics); builds without `NDEBUG` additionally emit a `RuntimeWarning` in that case. This handles `EnumInfo`-based builders (now emitted with `allow_existing=True`), which have no dialect prefix (e.g. `AtomicRMWKindAttr`, `Arith_CmpFPredicateAttr`) and may be emitted by every dialect whose td file includes the defining file; 2. Filter `EnumAttr` builders by `-bind-dialect` in `EnumPythonBindingGen.cpp` and register them under dialect-qualified keys (`"dialect.AttrName"`). Update `OpPythonBindingGen.cpp` to look up the same qualified keys for `EnumAttr`-typed op attributes (detected via `isSubClassOf("EnumAttr")`). Pass `-bind-dialect` from `AddMLIRPython.cmake`. This approach incurs no changes to `ir.py` registrations (no "builtin." prefix), and no manual builder additions to individual dialect Python files (unlike the previous attempt https://github.com/llvm/llvm-project/pull/117918). Note, this PR was "clauded" not "coded". 
--- mlir/cmake/modules/AddMLIRPython.cmake | 4 +-- mlir/include/mlir/Bindings/Python/Globals.h | 9 +++-- mlir/include/mlir/Bindings/Python/IRCore.h | 3 +- mlir/lib/Bindings/Python/Globals.cpp | 21 +++++++++--- mlir/lib/Bindings/Python/IRCore.cpp | 6 ++-- mlir/python/mlir/ir.py | 4 +-- .../test/mlir-tblgen/enums-python-bindings.td | 10 +++--- mlir/test/python/dialects/affine.py | 5 +++ mlir/test/python/dialects/index_dialect.py | 2 +- .../mlir-tblgen/EnumPythonBindingGen.cpp | 28 ++++++++++++---- mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 33 +++++++++++++++---- 11 files changed, 92 insertions(+), 33 deletions(-) diff --git a/mlir/cmake/modules/AddMLIRPython.cmake b/mlir/cmake/modules/AddMLIRPython.cmake index 6ac5003538e4..07f97e3261a3 100644 --- a/mlir/cmake/modules/AddMLIRPython.cmake +++ b/mlir/cmake/modules/AddMLIRPython.cmake @@ -608,7 +608,7 @@ function(declare_mlir_dialect_python_bindings) set(LLVM_TARGET_DEFINITIONS ${td_file}) endif() set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py") - mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME}) list(APPEND _sources ${enum_filename}) endif() @@ -680,7 +680,7 @@ function(declare_mlir_dialect_extension_python_bindings) set(LLVM_TARGET_DEFINITIONS ${td_file}) endif() set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py") - mlir_tablegen(${enum_filename} -gen-python-enum-bindings) + mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME}) list(APPEND _sources ${enum_filename}) endif() diff --git a/mlir/include/mlir/Bindings/Python/Globals.h b/mlir/include/mlir/Bindings/Python/Globals.h index 8a7f30fd218d..b2aa169744c9 100644 --- a/mlir/include/mlir/Bindings/Python/Globals.h +++ b/mlir/include/mlir/Bindings/Python/Globals.h @@ -58,11 +58,14 @@ public: bool loadDialectModule(std::string_view dialectNamespace); /// Adds a 
user-friendly Attribute builder. - /// Raises an exception if the mapping already exists and replace == false. + /// Raises an exception if the mapping already exists and replace == false + /// and allow_existing == false. + /// Silently skips registration if allow_existing == true and the mapping + /// already exists (first registration wins). /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - nanobind::callable pyFunc, - bool replace = false); + nanobind::callable pyFunc, bool replace = false, + bool allow_existing = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h index 557e32e9a612..f24b3c6ac6f8 100644 --- a/mlir/include/mlir/Bindings/Python/IRCore.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -1356,7 +1356,8 @@ struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap { static nanobind::callable dunderGetItemNamed(const std::string &attributeKind); static void dunderSetItemNamed(const std::string &attributeKind, - nanobind::callable func, bool replace); + nanobind::callable func, bool replace, + bool allow_existing); static void bind(nanobind::module_ &m); }; diff --git a/mlir/lib/Bindings/Python/Globals.cpp b/mlir/lib/Bindings/Python/Globals.cpp index 82195acb9f4f..1e48eac27dd8 100644 --- a/mlir/lib/Bindings/Python/Globals.cpp +++ b/mlir/lib/Bindings/Python/Globals.cpp @@ -97,14 +97,27 @@ bool PyGlobals::loadDialectModule(std::string_view dialectNamespace) { } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - nb::callable pyFunc, bool replace) { + nb::callable pyFunc, bool replace, + bool allowExisting) { nb::ft_lock_guard lock(mutex); nb::object &found = attributeBuilderMap[attributeKind]; - if (found && !replace) { - throw std::runtime_error( + if (found) { 
+ std::string msg = nanobind::detail::join("Attribute builder for '", attributeKind, "' is already registered with func: ", - nb::cast(nb::str(found)))); + nb::cast(nb::str(found))); + if (allowExisting) { +#ifndef NDEBUG + if (PyErr_WarnEx(PyExc_RuntimeWarning, msg.c_str(), 1) < 0) { + // If the user has set warnings to errors (e.g., via -Werror), + // PyErr_WarnEx returns -1 and sets a Python exception. + throw nb::python_error(); + } +#endif + return; + } + if (!replace) + throw std::runtime_error(msg); } found = std::move(pyFunc); } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3d07e364b5c9..89e1e21cd124 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -154,9 +154,10 @@ PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) { } void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { + nb::callable func, bool replace, + bool allow_existing) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), - replace); + replace, allow_existing); } void PyAttrBuilderMap::bind(nb::module_ &m) { @@ -171,6 +172,7 @@ void PyAttrBuilderMap::bind(nb::module_ &m) { "attribute kind.") .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, + "allow_existing"_a = false, "Register an attribute builder for building MLIR " "attributes from Python values."); } diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 210465daad0d..3795f5cb2e03 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -95,9 +95,9 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, Non # Convenience decorator for registering user-friendly Attribute builders. 
-def register_attribute_builder(kind, replace=False): +def register_attribute_builder(kind, replace=False, allow_existing=False): def decorator_builder(func): - AttrBuilder.insert(kind, func, replace=replace) + AttrBuilder.insert(kind, func, replace=replace, allow_existing=allow_existing) return func return decorator_builder diff --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td index cd23b6a2effb..74b9f51b0c2d 100644 --- a/mlir/test/mlir-tblgen/enums-python-bindings.td +++ b/mlir/test/mlir-tblgen/enums-python-bindings.td @@ -35,7 +35,7 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]> // CHECK: return "negone" // CHECK: raise ValueError("Unknown MyEnum enum entry.") -// CHECK: @register_attribute_builder("MyEnum") +// CHECK: @register_attribute_builder("MyEnum", allow_existing=True) // CHECK: def _myenum(x, context): // CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x)) @@ -58,7 +58,7 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]> // CHECK: return "two" // CHECK: raise ValueError("Unknown MyEnum64 enum entry.") -// CHECK: @register_attribute_builder("MyEnum64") +// CHECK: @register_attribute_builder("MyEnum64", allow_existing=True) // CHECK: def _myenum64(x, context): // CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x)) @@ -102,14 +102,14 @@ def TestBitEnum_Attr : EnumAttr; // CHECK: return "any" // CHECK: raise ValueError("Unknown TestBitEnum enum entry.") -// CHECK: @register_attribute_builder("TestBitEnum") +// CHECK: @register_attribute_builder("TestBitEnum", allow_existing=True) // CHECK: def _testbitenum(x, context): // CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x)) -// CHECK: @register_attribute_builder("TestBitEnum_Attr") +// CHECK: 
@register_attribute_builder("TestDialect.TestBitEnum_Attr") // CHECK: def _testbitenum_attr(x, context): // CHECK: return _ods_ir.Attribute.parse(f'#TestDialect', context=context) -// CHECK: @register_attribute_builder("TestMyEnum_Attr") +// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr") // CHECK: def _testmyenum_attr(x, context): // CHECK: return _ods_ir.Attribute.parse(f'#TestDialect', context=context) diff --git a/mlir/test/python/dialects/affine.py b/mlir/test/python/dialects/affine.py index c797234fd16d..1b1655692f15 100644 --- a/mlir/test/python/dialects/affine.py +++ b/mlir/test/python/dialects/affine.py @@ -335,6 +335,11 @@ def testAffineIfWithElse(): return +@constructAndPrintInModule +def test_double_AtomicRMWKindAttr_registration(): + from mlir.dialects import _affine_enum_gen + + # CHECK-LABEL: TEST: testAffineIfOpInsertionPoint @constructAndPrintInModule def testAffineIfOpInsertionPoint(): diff --git a/mlir/test/python/dialects/index_dialect.py b/mlir/test/python/dialects/index_dialect.py index 9db883469792..8da6a262cc44 100644 --- a/mlir/test/python/dialects/index_dialect.py +++ b/mlir/test/python/dialects/index_dialect.py @@ -94,7 +94,7 @@ def testCeilDivUOp(ctx): def testCmpOp(ctx): a = index.ConstantOp(value=42) b = index.ConstantOp(value=23) - pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx) + pred = AttrBuilder.get("index.IndexCmpPredicateAttr")("slt", context=ctx) r = index.CmpOp(pred, lhs=a, rhs=b) # CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}}) diff --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp index acc9b61d7121..6cef09d9958c 100644 --- a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp @@ -17,6 +17,7 @@ #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/GenInfo.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" 
#include "llvm/TableGen/Record.h" @@ -26,6 +27,10 @@ using llvm::formatv; using llvm::Record; using llvm::RecordKeeper; +// Declared in OpPythonBindingGen.cpp; the two generators share the same +// -bind-dialect option to allow filtering enum registrations by dialect. +extern std::string dialectNameStorage; + /// File header and includes. constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. @@ -94,7 +99,11 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) { return false; int64_t bitwidth = enumInfo.getBitwidth(); - os << formatv("@register_attribute_builder(\"{0}\")\n", + // These builders may be emitted by multiple dialect enum_gen files when + // dialects share enum definitions via .td includes. Use allow_existing=True + // so that the first loaded dialect registers the builder and subsequent + // loads silently skip (first-registration wins). + os << formatv("@register_attribute_builder(\"{0}\", allow_existing=True)\n", enumAttrInfo->getAttrDefName()); os << formatv("def _{0}(x, context):\n", enumAttrInfo->getAttrDefName().lower()); @@ -108,10 +117,12 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) { /// Emits an attribute builder for the given dialect enum attribute to support /// automatic conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. 
-static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, +static bool emitDialectEnumAttributeBuilder(StringRef dialect, + StringRef attrDefName, StringRef formatString, raw_ostream &os) { - os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); + os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect, + attrDefName); os << formatv("def _{0}(x, context):\n", attrDefName.lower()); os << formatv(" return " "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", @@ -132,6 +143,12 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { for (const Record *it : records.getAllDerivedDefinitionsIfDefined("EnumAttr")) { AttrOrTypeDef attr(&*it); + StringRef dialect = attr.getDialect().getName(); + // When -bind-dialect is specified, only emit builders for EnumAttr records + // belonging to that dialect. This prevents duplicate registrations when + // multiple dialects include the same .td files. + if (!dialectNameStorage.empty() && dialect != dialectNameStorage) + continue; if (!attr.getMnemonic()) { llvm::errs() << "enum case " << attr << " needs mnemonic for python enum bindings generation"; @@ -139,14 +156,13 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { } StringRef mnemonic = attr.getMnemonic().value(); std::optional assemblyFormat = attr.getAssemblyFormat(); - StringRef dialect = attr.getDialect().getName(); if (assemblyFormat == "`<` $value `>`") { emitDialectEnumAttributeBuilder( - attr.getName(), + dialect, attr.getName(), formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); } else if (assemblyFormat == "$value") { emitDialectEnumAttributeBuilder( - attr.getName(), + dialect, attr.getName(), formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); } else { llvm::errs() diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index e8acf4ce40fc..84dce9bdf0c6 100644 --- 
a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -341,10 +341,13 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]: static llvm::cl::OptionCategory clOpPythonBindingCat("Options for -gen-python-op-bindings"); -static llvm::cl::opt +std::string dialectNameStorage; + +llvm::cl::opt clDialectName("bind-dialect", llvm::cl::desc("The dialect to run the generator for"), - llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); + llvm::cl::location(dialectNameStorage), + llvm::cl::cat(clOpPythonBindingCat)); static llvm::cl::opt clDialectExtensionName( "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"), @@ -887,11 +890,27 @@ populateBuilderLinesAttr(const Operator &op, ArrayRef argNames, continue; } + // For EnumAttr-style attributes (those defined as EnumAttr + // in tablegen), use a dialect-qualified key ("dialect.AttrName") so the + // lookup matches the registration emitted by EnumPythonBindingGen with + // -bind-dialect. For all other attributes (plain attrs like I32Attr, + // custom AttrDef, etc.), keep the unqualified name to match their + // registrations in ir.py or dialect-specific Python files. + Attribute baseAttr = attribute->attr.getBaseAttr(); + Dialect attrDialect = baseAttr.isSubClassOf("EnumAttr") + ? baseAttr.getDialect() + : Dialect(nullptr); + std::string attrBuilderKey = attrDialect + ? formatv("{0}.{1}", attrDialect.getName(), + attribute->attr.getAttrDefName()) + .str() + : attribute->attr.getAttrDefName().str(); + builderLines.push_back(formatv( attribute->attr.isOptional() || attribute->attr.hasDefaultValue() ? initOptionalAttributeWithBuilderTemplate : initAttributeWithBuilderTemplate, - argNames[i], attribute->name, attribute->attr.getAttrDefName())); + argNames[i], attribute->name, attrBuilderKey)); } } @@ -1307,18 +1326,18 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) { /// headers and utilities. 
Returns `false` on success to comply with Tablegen /// registration requirements. static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) { - if (clDialectName.empty()) + if (dialectNameStorage.empty()) llvm::PrintFatalError("dialect name not provided"); os << fileHeader; if (!clDialectExtensionName.empty()) - os << formatv(dialectExtensionTemplate, clDialectName.getValue()); + os << formatv(dialectExtensionTemplate, dialectNameStorage); else - os << formatv(dialectClassTemplate, clDialectName.getValue()); + os << formatv(dialectClassTemplate, dialectNameStorage); for (const Record *rec : records.getAllDerivedDefinitions("Op")) { Operator op(rec); - if (op.getDialectName() == clDialectName.getValue()) + if (op.getDialectName() == dialectNameStorage) emitOpBindings(op, os); } return false;