[mlir-python] Fix duplicate EnumAttr builder registration across dialects. (#187191)
When multiple dialects share td `#includes` (e.g. `affine` includes
`arith`), each dialect's `*_enum_gen.py` file registers attribute
builders under the same keys, causing "already registered" errors on the
second import; the first commit checks in such a case which currently
fails on main:
```
# | RuntimeError: Attribute builder for 'Arith_CmpFPredicateAttr' is already registered with func: <function _arith_cmpfpredicateattr at 0x78d13cbe9a80>
```
This PR implements a two-pronged fix:
1. Add `allow_existing=True` to `register_attribute_builder` (and the
underlying C++ `registerAttributeBuilder`). When set, silently skips
registration if the key already exists (first-wins semantics). This
handles `EnumInfo`-based builders which have no dialect prefix (e.g.
`AtomicRMWKindAttr`, `Arith_CmpFPredicateAttr`), which may be emitted by
every dialect whose td file includes the defining file;
2. Filter `EnumAttr` builders by `-bind-dialect` in
`EnumPythonBindingGen.cpp` and register them under dialect qualified
keys (`"dialect.AttrName"`). Update `OpPythonBindingGen.cpp` to look up
the same qualified keys for EnumAttr typed op attributes (detected via
`isSubClassOf("EnumAttr")`). Pass `-bind-dialect` from
`AddMLIRPython.cmake`.
This approach incurs no changes to `ir.py` registrations (no "builtin."
prefix), and no manual builder additions to individual dialect Python
files (unlike the previous attempt
https://github.com/llvm/llvm-project/pull/117918).
Note, this PR was "clauded" not "coded".
This commit is contained in:
parent
fa2df7e853
commit
c3f381ccfe
@ -608,7 +608,7 @@ function(declare_mlir_dialect_python_bindings)
|
||||
set(LLVM_TARGET_DEFINITIONS ${td_file})
|
||||
endif()
|
||||
set(enum_filename "${relative_td_directory}/_${ARG_DIALECT_NAME}_enum_gen.py")
|
||||
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
|
||||
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
|
||||
list(APPEND _sources ${enum_filename})
|
||||
endif()
|
||||
|
||||
@ -680,7 +680,7 @@ function(declare_mlir_dialect_extension_python_bindings)
|
||||
set(LLVM_TARGET_DEFINITIONS ${td_file})
|
||||
endif()
|
||||
set(enum_filename "${relative_td_directory}/_${ARG_EXTENSION_NAME}_enum_gen.py")
|
||||
mlir_tablegen(${enum_filename} -gen-python-enum-bindings)
|
||||
mlir_tablegen(${enum_filename} -gen-python-enum-bindings -bind-dialect=${ARG_DIALECT_NAME})
|
||||
list(APPEND _sources ${enum_filename})
|
||||
endif()
|
||||
|
||||
|
||||
@ -58,11 +58,14 @@ public:
|
||||
bool loadDialectModule(std::string_view dialectNamespace);
|
||||
|
||||
/// Adds a user-friendly Attribute builder.
|
||||
/// Raises an exception if the mapping already exists and replace == false.
|
||||
/// Raises an exception if the mapping already exists and replace == false
|
||||
/// and allow_existing == false.
|
||||
/// Silently skips registration if allow_existing == true and the mapping
|
||||
/// already exists (first registration wins).
|
||||
/// This is intended to be called by implementation code.
|
||||
void registerAttributeBuilder(const std::string &attributeKind,
|
||||
nanobind::callable pyFunc,
|
||||
bool replace = false);
|
||||
nanobind::callable pyFunc, bool replace = false,
|
||||
bool allow_existing = false);
|
||||
|
||||
/// Adds a user-friendly type caster. Raises an exception if the mapping
|
||||
/// already exists and replace == false. This is intended to be called by
|
||||
|
||||
@ -1356,7 +1356,8 @@ struct MLIR_PYTHON_API_EXPORTED PyAttrBuilderMap {
|
||||
static nanobind::callable
|
||||
dunderGetItemNamed(const std::string &attributeKind);
|
||||
static void dunderSetItemNamed(const std::string &attributeKind,
|
||||
nanobind::callable func, bool replace);
|
||||
nanobind::callable func, bool replace,
|
||||
bool allow_existing);
|
||||
|
||||
static void bind(nanobind::module_ &m);
|
||||
};
|
||||
|
||||
@ -97,14 +97,27 @@ bool PyGlobals::loadDialectModule(std::string_view dialectNamespace) {
|
||||
}
|
||||
|
||||
void PyGlobals::registerAttributeBuilder(const std::string &attributeKind,
|
||||
nb::callable pyFunc, bool replace) {
|
||||
nb::callable pyFunc, bool replace,
|
||||
bool allowExisting) {
|
||||
nb::ft_lock_guard lock(mutex);
|
||||
nb::object &found = attributeBuilderMap[attributeKind];
|
||||
if (found && !replace) {
|
||||
throw std::runtime_error(
|
||||
if (found) {
|
||||
std::string msg =
|
||||
nanobind::detail::join("Attribute builder for '", attributeKind,
|
||||
"' is already registered with func: ",
|
||||
nb::cast<std::string>(nb::str(found))));
|
||||
nb::cast<std::string>(nb::str(found)));
|
||||
if (allowExisting) {
|
||||
#ifndef NDEBUG
|
||||
if (PyErr_WarnEx(PyExc_RuntimeWarning, msg.c_str(), 1) < 0) {
|
||||
// If the user has set warnings to errors (e.g., via -Werror),
|
||||
// PyErr_WarnEx returns -1 and sets a Python exception.
|
||||
throw nb::python_error();
|
||||
}
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
if (!replace)
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
found = std::move(pyFunc);
|
||||
}
|
||||
|
||||
@ -154,9 +154,10 @@ PyAttrBuilderMap::dunderGetItemNamed(const std::string &attributeKind) {
|
||||
}
|
||||
|
||||
void PyAttrBuilderMap::dunderSetItemNamed(const std::string &attributeKind,
|
||||
nb::callable func, bool replace) {
|
||||
nb::callable func, bool replace,
|
||||
bool allow_existing) {
|
||||
PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func),
|
||||
replace);
|
||||
replace, allow_existing);
|
||||
}
|
||||
|
||||
void PyAttrBuilderMap::bind(nb::module_ &m) {
|
||||
@ -171,6 +172,7 @@ void PyAttrBuilderMap::bind(nb::module_ &m) {
|
||||
"attribute kind.")
|
||||
.def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed,
|
||||
"attribute_kind"_a, "attr_builder"_a, "replace"_a = false,
|
||||
"allow_existing"_a = false,
|
||||
"Register an attribute builder for building MLIR "
|
||||
"attributes from Python values.");
|
||||
}
|
||||
|
||||
@ -95,9 +95,9 @@ def loc_tracebacks(*, max_depth: int | None = None) -> Generator[None, None, Non
|
||||
|
||||
|
||||
# Convenience decorator for registering user-friendly Attribute builders.
|
||||
def register_attribute_builder(kind, replace=False):
|
||||
def register_attribute_builder(kind, replace=False, allow_existing=False):
|
||||
def decorator_builder(func):
|
||||
AttrBuilder.insert(kind, func, replace=replace)
|
||||
AttrBuilder.insert(kind, func, replace=replace, allow_existing=allow_existing)
|
||||
return func
|
||||
|
||||
return decorator_builder
|
||||
|
||||
@ -35,7 +35,7 @@ def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two, NegOne]>
|
||||
// CHECK: return "negone"
|
||||
// CHECK: raise ValueError("Unknown MyEnum enum entry.")
|
||||
|
||||
// CHECK: @register_attribute_builder("MyEnum")
|
||||
// CHECK: @register_attribute_builder("MyEnum", allow_existing=True)
|
||||
// CHECK: def _myenum(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
@ -58,7 +58,7 @@ def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>
|
||||
// CHECK: return "two"
|
||||
// CHECK: raise ValueError("Unknown MyEnum64 enum entry.")
|
||||
|
||||
// CHECK: @register_attribute_builder("MyEnum64")
|
||||
// CHECK: @register_attribute_builder("MyEnum64", allow_existing=True)
|
||||
// CHECK: def _myenum64(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), int(x))
|
||||
|
||||
@ -102,14 +102,14 @@ def TestBitEnum_Attr : EnumAttr<Test_Dialect, TestBitEnum, "testbitenum">;
|
||||
// CHECK: return "any"
|
||||
// CHECK: raise ValueError("Unknown TestBitEnum enum entry.")
|
||||
|
||||
// CHECK: @register_attribute_builder("TestBitEnum")
|
||||
// CHECK: @register_attribute_builder("TestBitEnum", allow_existing=True)
|
||||
// CHECK: def _testbitenum(x, context):
|
||||
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), int(x))
|
||||
|
||||
// CHECK: @register_attribute_builder("TestBitEnum_Attr")
|
||||
// CHECK: @register_attribute_builder("TestDialect.TestBitEnum_Attr")
|
||||
// CHECK: def _testbitenum_attr(x, context):
|
||||
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<testbitenum {str(x)}>', context=context)
|
||||
|
||||
// CHECK: @register_attribute_builder("TestMyEnum_Attr")
|
||||
// CHECK: @register_attribute_builder("TestDialect.TestMyEnum_Attr")
|
||||
// CHECK: def _testmyenum_attr(x, context):
|
||||
// CHECK: return _ods_ir.Attribute.parse(f'#TestDialect<enum {str(x)}>', context=context)
|
||||
|
||||
@ -335,6 +335,11 @@ def testAffineIfWithElse():
|
||||
return
|
||||
|
||||
|
||||
@constructAndPrintInModule
|
||||
def test_double_AtomicRMWKindAttr_registration():
|
||||
from mlir.dialects import _affine_enum_gen
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAffineIfOpInsertionPoint
|
||||
@constructAndPrintInModule
|
||||
def testAffineIfOpInsertionPoint():
|
||||
|
||||
@ -94,7 +94,7 @@ def testCeilDivUOp(ctx):
|
||||
def testCmpOp(ctx):
|
||||
a = index.ConstantOp(value=42)
|
||||
b = index.ConstantOp(value=23)
|
||||
pred = AttrBuilder.get("IndexCmpPredicateAttr")("slt", context=ctx)
|
||||
pred = AttrBuilder.get("index.IndexCmpPredicateAttr")("slt", context=ctx)
|
||||
r = index.CmpOp(pred, lhs=a, rhs=b)
|
||||
# CHECK: %{{.*}} = index.cmp slt(%{{.*}}, %{{.*}})
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include "mlir/TableGen/Dialect.h"
|
||||
#include "mlir/TableGen/EnumInfo.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
@ -26,6 +27,10 @@ using llvm::formatv;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
|
||||
// Declared in OpPythonBindingGen.cpp; the two generators share the same
|
||||
// -bind-dialect option to allow filtering enum registrations by dialect.
|
||||
extern std::string dialectNameStorage;
|
||||
|
||||
/// File header and includes.
|
||||
constexpr const char *fileHeader = R"Py(
|
||||
# Autogenerated by mlir-tblgen; don't manually edit.
|
||||
@ -94,7 +99,11 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
|
||||
return false;
|
||||
|
||||
int64_t bitwidth = enumInfo.getBitwidth();
|
||||
os << formatv("@register_attribute_builder(\"{0}\")\n",
|
||||
// These builders may be emitted by multiple dialect enum_gen files when
|
||||
// dialects share enum definitions via .td includes. Use allow_existing=True
|
||||
// so that the first loaded dialect registers the builder and subsequent
|
||||
// loads silently skip (first-registration wins).
|
||||
os << formatv("@register_attribute_builder(\"{0}\", allow_existing=True)\n",
|
||||
enumAttrInfo->getAttrDefName());
|
||||
os << formatv("def _{0}(x, context):\n",
|
||||
enumAttrInfo->getAttrDefName().lower());
|
||||
@ -108,10 +117,12 @@ static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) {
|
||||
/// Emits an attribute builder for the given dialect enum attribute to support
|
||||
/// automatic conversion between enum values and attributes in Python. Returns
|
||||
/// `false` on success, `true` on failure.
|
||||
static bool emitDialectEnumAttributeBuilder(StringRef attrDefName,
|
||||
static bool emitDialectEnumAttributeBuilder(StringRef dialect,
|
||||
StringRef attrDefName,
|
||||
StringRef formatString,
|
||||
raw_ostream &os) {
|
||||
os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName);
|
||||
os << formatv("@register_attribute_builder(\"{0}.{1}\")\n", dialect,
|
||||
attrDefName);
|
||||
os << formatv("def _{0}(x, context):\n", attrDefName.lower());
|
||||
os << formatv(" return "
|
||||
"_ods_ir.Attribute.parse(f'{0}', context=context)\n\n",
|
||||
@ -132,6 +143,12 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
|
||||
for (const Record *it :
|
||||
records.getAllDerivedDefinitionsIfDefined("EnumAttr")) {
|
||||
AttrOrTypeDef attr(&*it);
|
||||
StringRef dialect = attr.getDialect().getName();
|
||||
// When -bind-dialect is specified, only emit builders for EnumAttr records
|
||||
// belonging to that dialect. This prevents duplicate registrations when
|
||||
// multiple dialects include the same .td files.
|
||||
if (!dialectNameStorage.empty() && dialect != dialectNameStorage)
|
||||
continue;
|
||||
if (!attr.getMnemonic()) {
|
||||
llvm::errs() << "enum case " << attr
|
||||
<< " needs mnemonic for python enum bindings generation";
|
||||
@ -139,14 +156,13 @@ static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) {
|
||||
}
|
||||
StringRef mnemonic = attr.getMnemonic().value();
|
||||
std::optional<StringRef> assemblyFormat = attr.getAssemblyFormat();
|
||||
StringRef dialect = attr.getDialect().getName();
|
||||
if (assemblyFormat == "`<` $value `>`") {
|
||||
emitDialectEnumAttributeBuilder(
|
||||
attr.getName(),
|
||||
dialect, attr.getName(),
|
||||
formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os);
|
||||
} else if (assemblyFormat == "$value") {
|
||||
emitDialectEnumAttributeBuilder(
|
||||
attr.getName(),
|
||||
dialect, attr.getName(),
|
||||
formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os);
|
||||
} else {
|
||||
llvm::errs()
|
||||
|
||||
@ -341,10 +341,13 @@ def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
|
||||
static llvm::cl::OptionCategory
|
||||
clOpPythonBindingCat("Options for -gen-python-op-bindings");
|
||||
|
||||
static llvm::cl::opt<std::string>
|
||||
std::string dialectNameStorage;
|
||||
|
||||
llvm::cl::opt<std::string, /*ExternalStorage=*/true>
|
||||
clDialectName("bind-dialect",
|
||||
llvm::cl::desc("The dialect to run the generator for"),
|
||||
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
|
||||
llvm::cl::location(dialectNameStorage),
|
||||
llvm::cl::cat(clOpPythonBindingCat));
|
||||
|
||||
static llvm::cl::opt<std::string> clDialectExtensionName(
|
||||
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
|
||||
@ -887,11 +890,27 @@ populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
|
||||
continue;
|
||||
}
|
||||
|
||||
// For EnumAttr-style attributes (those defined as EnumAttr<Dialect, ...>
|
||||
// in tablegen), use a dialect-qualified key ("dialect.AttrName") so the
|
||||
// lookup matches the registration emitted by EnumPythonBindingGen with
|
||||
// -bind-dialect. For all other attributes (plain attrs like I32Attr,
|
||||
// custom AttrDef, etc.), keep the unqualified name to match their
|
||||
// registrations in ir.py or dialect-specific Python files.
|
||||
Attribute baseAttr = attribute->attr.getBaseAttr();
|
||||
Dialect attrDialect = baseAttr.isSubClassOf("EnumAttr")
|
||||
? baseAttr.getDialect()
|
||||
: Dialect(nullptr);
|
||||
std::string attrBuilderKey = attrDialect
|
||||
? formatv("{0}.{1}", attrDialect.getName(),
|
||||
attribute->attr.getAttrDefName())
|
||||
.str()
|
||||
: attribute->attr.getAttrDefName().str();
|
||||
|
||||
builderLines.push_back(formatv(
|
||||
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
|
||||
? initOptionalAttributeWithBuilderTemplate
|
||||
: initAttributeWithBuilderTemplate,
|
||||
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
|
||||
argNames[i], attribute->name, attrBuilderKey));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1307,18 +1326,18 @@ static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
||||
/// headers and utilities. Returns `false` on success to comply with Tablegen
|
||||
/// registration requirements.
|
||||
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
|
||||
if (clDialectName.empty())
|
||||
if (dialectNameStorage.empty())
|
||||
llvm::PrintFatalError("dialect name not provided");
|
||||
|
||||
os << fileHeader;
|
||||
if (!clDialectExtensionName.empty())
|
||||
os << formatv(dialectExtensionTemplate, clDialectName.getValue());
|
||||
os << formatv(dialectExtensionTemplate, dialectNameStorage);
|
||||
else
|
||||
os << formatv(dialectClassTemplate, clDialectName.getValue());
|
||||
os << formatv(dialectClassTemplate, dialectNameStorage);
|
||||
|
||||
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
|
||||
Operator op(rec);
|
||||
if (op.getDialectName() == clDialectName.getValue())
|
||||
if (op.getDialectName() == dialectNameStorage)
|
||||
emitOpBindings(op, os);
|
||||
}
|
||||
return false;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user