llvm-project/mlir/test/mlir-tblgen/enums-python-bindings.td
Maksim Levental c3f381ccfe
[mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191)
When multiple dialects share TableGen `include`s (e.g. `affine` includes
`arith`), each dialect's `*_enum_gen.py` file registers attribute
builders under the same keys, causing "already registered" errors on the
second import. The first commit adds a test reproducing this case, which
currently fails on main:

```
# | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: <function _arith_cmpfpredicateattr at 0x78d13cbe9a80>
```

This PR implements a two-pronged fix:

1. Add `allow_existing=True` to `register_attribute_builder` (and the
underlying C++ `registerAttributeBuilder`). When set, registration is
silently skipped if the key already exists (first-wins semantics). This
handles `EnumInfo`-based builders, which have no dialect prefix (e.g.
`AtomicRMWKindAttr`, `Arith_CmpFPredicateAttr`) and may be emitted by
every dialect whose td file includes the defining file;
2. Filter `EnumAttr` builders by `-bind-dialect` in
`EnumPythonBindingGen.cpp` and register them under dialect-qualified
keys (`"dialect.AttrName"`). Update `OpPythonBindingGen.cpp` to look up
the same qualified keys for EnumAttr-typed op attributes (detected via
`isSubClassOf("EnumAttr")`). Pass `-bind-dialect` from
`AddMLIRPython.cmake`.

This approach requires no changes to `ir.py` registrations (no "builtin."
prefix) and no manual builder additions to individual dialect Python
files (unlike the previous attempt
https://github.com/llvm/llvm-project/pull/117918).

Note, this PR was "clauded" not "coded".
2026-03-19 21:02:23 -07:00

116 lines
4.2 KiB
TableGen

// RUN: mlir-tblgen -gen-python-enum-bindings %s -I %S/../../include | FileCheck %s
// RUN: mlir-tblgen -gen-python-enum-bindings -bind-dialect=TestDialect %s -I %S/../../include | FileCheck %s
// RUN: mlir-tblgen -gen-python-enum-bindings -bind-dialect=OtherDialect %s -I %S/../../include | FileCheck %s --check-prefix=OTHER

// Exercises the Python enum binding generator in three invocations:
//   1. no dialect filter: every EnumAttr builder in this file is emitted;
//   2. a filter naming Test_Dialect (mirroring how the Python CMake build
//      invokes the generator): must satisfy the same expectations as (1);
//   3. a filter naming a dialect that is not defined here: the EnumInfo
//      builders (bare keys, allow_existing=True) are still emitted, but no
//      dialect-qualified EnumAttr builder may be registered anywhere.
include "mlir/IR/EnumAttr.td"

// Minimal dialect owning the EnumAttr definitions in this file; its `name`
// ("TestDialect") is the prefix of the dialect-qualified builder keys.
def Test_Dialect : Dialect {
  let name = "TestDialect";
  let cppNamespace = "::test";
}

// Generated module header: imports and the `_ods_ir` alias used by builders.
// CHECK: Autogenerated by mlir-tblgen; don't manually edit.
// CHECK: from enum import IntEnum, auto, IntFlag
// CHECK: from ._ods_common import _cext as _ods_cext
// CHECK: from ..ir import register_attribute_builder
// CHECK: _ods_ir = _ods_cext.ir

// Invocation 3: the filter removes every dialect-qualified builder, both
// before and after the last EnumInfo builder, while that builder survives.
// OTHER-NOT: register_attribute_builder("TestDialect.
// OTHER: @register_attribute_builder("TestBitEnum", allow_existing=True)
// OTHER-NOT: register_attribute_builder("TestDialect.
// A 32-bit enum (I32EnumAttr) that includes a negative case value.
// Expected output: an IntEnum subclass whose __str__ returns each case's
// string form, plus a builder registered under the bare enum name with
// allow_existing=True. That flag makes a repeated registration of the same
// key (e.g. from another dialect's generated module that also includes this
// definition) a no-op instead of an error. The builder produces a signless
// 32-bit IntegerAttr.
def One : I32EnumAttrCase<"CaseOne", 1, "one">;
def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
def NegOne : I32EnumAttrCase<"CaseNegOne", -1, "negone">;
def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>;
// CHECK-LABEL: class MyEnum(IntEnum):
// CHECK: """An example 32-bit enum"""
// CHECK: CaseOne = 1
// CHECK: CaseTwo = 2
// NOTE(review): the -1 case is expected as `auto()` rather than the literal
// value; confirm this is the intended generator behavior for negative values.
// CHECK: CaseNegOne = auto()
// CHECK: def __str__(self):
// CHECK: if self is MyEnum.CaseOne:
// CHECK: return "one"
// CHECK: if self is MyEnum.CaseTwo:
// CHECK: return "two"
// CHECK: if self is MyEnum.CaseNegOne:
// CHECK: return "negone"
// CHECK: raise ValueError("Unknown MyEnum enum entry.")
// CHECK: @register_attribute_builder("MyEnum", allow_existing=True)
// CHECK: def _myenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
// Wraps MyEnum as a Test_Dialect attribute with mnemonic "enum"; its
// dialect-qualified builder is expected near the end of this file.
def TestMyEnum_Attr : EnumAttr<Test_Dialect, MyEnum, "enum">;
// A 64-bit enum (I64EnumAttr). No EnumAttr wraps it here, so only the
// IntEnum class and its bare-key builder (allow_existing=True, producing a
// signless 64-bit IntegerAttr) are expected for it.
def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;
def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
// CHECK-LABEL: class MyEnum64(IntEnum):
// CHECK: """An example 64-bit enum"""
// CHECK: CaseOne64 = 1
// CHECK: CaseTwo64 = 2
// CHECK: def __str__(self):
// CHECK: if self is MyEnum64.CaseOne64:
// CHECK: return "one"
// CHECK: if self is MyEnum64.CaseTwo64:
// CHECK: return "two"
// CHECK: raise ValueError("Unknown MyEnum64 enum entry.")
// CHECK: @register_attribute_builder("MyEnum64", allow_existing=True)
// CHECK: def _myenum64(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
// A 32-bit bit enum. Each I32BitEnumAttrCaseBit names a bit position
// (0, 1, 2), which appears in Python as the values 1, 2 and 4; the case
// group "Any" is their union (7).
def User : I32BitEnumAttrCaseBit<"User", 0, "user">;
def Group : I32BitEnumAttrCaseBit<"Group", 1, "group">;
def Other : I32BitEnumAttrCaseBit<"Other", 2, "other">;
// genSpecializedAttr is disabled here. The " | " separator is what the
// generated __str__ uses to join the names of multi-bit values.
def TestBitEnum
: I32BitEnumAttr<
"TestBitEnum", "",
[User, Group, Other,
I32BitEnumAttrCaseGroup<"Any", [User, Group, Other], "any">]> {
let genSpecializedAttr = 0;
let separator = " | ";
}
// Wraps TestBitEnum as a Test_Dialect attribute with mnemonic "testbitenum".
def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
// Expected: an IntFlag subclass whose __iter__ yields the constituent cases
// other than the value itself, and whose __len__ counts the set bits.
// CHECK-LABEL: class TestBitEnum(IntFlag):
// CHECK: User = 1
// CHECK: Group = 2
// CHECK: Other = 4
// CHECK: Any = 7
// CHECK: def __iter__(self):
// CHECK: return iter([case for case in type(self) if (self & case) is case and self is not case])
// CHECK: def __len__(self):
// CHECK: return bin(self).count("1")
// CHECK: def __str__(self):
// CHECK: if len(self) > 1:
// CHECK: return " | ".join(map(str, self))
// CHECK: if self is TestBitEnum.User:
// CHECK: return "user"
// CHECK: if self is TestBitEnum.Group:
// CHECK: return "group"
// CHECK: if self is TestBitEnum.Other:
// CHECK: return "other"
// CHECK: if self is TestBitEnum.Any:
// CHECK: return "any"
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
// CHECK: @register_attribute_builder("TestBitEnum", allow_existing=True)
// CHECK: def _testbitenum(x, context):
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
// Dialect-qualified EnumAttr builders, expected after all enum classes.
// They are keyed "TestDialect.<attr def name>" without allow_existing, and
// they build the attribute by parsing its textual form
// `#TestDialect<mnemonic value>`.
// CHECK: @register_attribute_builder("TestDialect.TestBitEnum_Attr")
// CHECK: def _testbitenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr")
// CHECK: def _testmyenum_attr(x, context):
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)