//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to // generate the corresponding Python binding classes. // //===----------------------------------------------------------------------===// #include "OpGenHelpers.h" #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/EnumInfo.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" using namespace mlir; using namespace mlir::tblgen; using llvm::formatv; using llvm::Record; using llvm::RecordKeeper; /// File header and includes. constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from enum import IntEnum, auto, IntFlag from ._ods_common import _cext as _ods_cext from ..ir import register_attribute_builder _ods_ir = _ods_cext.ir )Py"; /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. static std::string makePythonEnumCaseName(StringRef name) { if (isPythonReserved(name.str())) return (name + "_").str(); return name.str(); } /// Emits the Python class for the given enum. static void emitEnumClass(EnumInfo enumInfo, raw_ostream &os) { os << formatv("class {0}({1}):\n", enumInfo.getEnumClassName(), enumInfo.isBitEnum() ? "IntFlag" : "IntEnum"); if (!enumInfo.getSummary().empty()) os << formatv(" \"\"\"{0}\"\"\"\n", enumInfo.getSummary()); os << "\n"; for (const EnumCase &enumCase : enumInfo.getAllCases()) { os << formatv(" {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()), enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) : "auto()"); } os << "\n"; if (enumInfo.isBitEnum()) { os << formatv(" def __iter__(self):\n" " return iter([case for case in type(self) if " "(self & case) is case and self is not case])\n"); os << formatv(" def __len__(self):\n" " return bin(self).count(\"1\")\n"); os << "\n"; } os << formatv(" def __str__(self):\n"); if (enumInfo.isBitEnum()) os << formatv(" if len(self) > 1:\n" " return \"{0}\".join(map(str, self))\n", enumInfo.getDef().getValueAsString("separator")); for (const EnumCase &enumCase : enumInfo.getAllCases()) { os << formatv(" if self is {0}.{1}:\n", enumInfo.getEnumClassName(), makePythonEnumCaseName(enumCase.getSymbol())); os << formatv(" return \"{0}\"\n", enumCase.getStr()); } os << formatv(" raise ValueError(\"Unknown {0} enum entry.\")\n\n\n", enumInfo.getEnumClassName()); os << "\n"; } /// Emits an attribute builder for the given enum attribute to support automatic /// conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. static bool emitAttributeBuilder(const EnumInfo &enumInfo, raw_ostream &os) { std::optional enumAttrInfo = enumInfo.asEnumAttr(); if (!enumAttrInfo) return false; int64_t bitwidth = enumInfo.getBitwidth(); os << formatv("@register_attribute_builder(\"{0}\")\n", enumAttrInfo->getAttrDefName()); os << formatv("def _{0}(x, context):\n", enumAttrInfo->getAttrDefName().lower()); os << formatv(" return " "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " "context=context), int(x))\n\n", bitwidth); return false; } /// Emits an attribute builder for the given dialect enum attribute to support /// automatic conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, StringRef formatString, raw_ostream &os) { os << formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); os << formatv("def _{0}(x, context):\n", attrDefName.lower()); os << formatv(" return " "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", formatString); return false; } /// Emits Python bindings for all enums in the record keeper. Returns /// `false` on success, `true` on failure. static bool emitPythonEnums(const RecordKeeper &records, raw_ostream &os) { os << fileHeader; for (const Record *it : records.getAllDerivedDefinitionsIfDefined("EnumInfo")) { EnumInfo enumInfo(*it); emitEnumClass(enumInfo, os); emitAttributeBuilder(enumInfo, os); } for (const Record *it : records.getAllDerivedDefinitionsIfDefined("EnumAttr")) { AttrOrTypeDef attr(&*it); if (!attr.getMnemonic()) { llvm::errs() << "enum case " << attr << " needs mnemonic for python enum bindings generation"; return true; } StringRef mnemonic = attr.getMnemonic().value(); std::optional assemblyFormat = attr.getAssemblyFormat(); StringRef dialect = attr.getDialect().getName(); if (assemblyFormat == "`<` $value `>`") { emitDialectEnumAttributeBuilder( attr.getName(), formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); } else if (assemblyFormat == "$value") { emitDialectEnumAttributeBuilder( attr.getName(), formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); } else { llvm::errs() << "unsupported assembly format for python enum bindings generation"; return true; } } return false; } // Registers the enum utility generator to mlir-tblgen. static mlir::GenRegistration genPythonEnumBindings("gen-python-enum-bindings", "Generate Python bindings for enum attributes", &emitPythonEnums);