llvm-project/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Maksim Levental a40f47c972
[mlir][python] automatic location inference (#151246)
This PR implements "automatic" location inference in the bindings. The
way it works is it walks the frame stack collecting source locations
(Python captures these in the frame itself). It is inspired by JAX's
[implementation](523ddcfbca/jax/_src/interpreters/mlir.py (L462))
but moves the frame stack traversal into the bindings for better
performance.

The system supports registering "included" and "excluded" filenames;
frames originating from functions in included filenames **will not** be
filtered and frames originating from functions in excluded filenames
**will** be filtered (in that order). This allows excluding all the
generated `*_ops_gen.py` files.

The system is also "toggleable" and off by default to save people who
have their own systems (such as JAX) from the added cost.

Note, the system stores the entire stacktrace (subject to
`locTracebackFramesLimit`) in the `Location` using specifically a
`CallSiteLoc`. This can be useful for profiling tools (flamegraphs
etc.).

Shoutout to the folks at JAX for coming up with a good system.

---------

Co-authored-by: Jacques Pienaar <jpienaar@google.com>
2025-08-12 16:59:59 -05:00

1074 lines
42 KiB
C++

//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//
#include "OpGenHelpers.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;
/// File header and includes for the generated Python module.
/// The template has no format placeholders: it imports the `_ods_common`
/// helpers under private aliases, binds `_ods_ir` to `_ods_cext.ir`, and calls
/// `register_traceback_file_exclusion(__file__)` on `_ods_cext.globals`.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_result_or_op_results as _get_op_result_or_op_results,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)
import builtins
from typing import Sequence as _Sequence, Union as _Union
)Py";
/// Template for the dialect class. Declares a Python class `_Dialect` deriving
/// from `_ods_ir.Dialect` and decorated with `_ods_cext.register_dialect`:
///   {0} is the dialect namespace (stored in `DIALECT_NAMESPACE`).
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
DIALECT_NAMESPACE = "{0}"
)Py";
/// Template importing `_Dialect` from a sibling generated module instead of
/// declaring it:
///   {0} is the name spliced into the module name `._{0}_ops_gen`
///   (NOTE(review): presumably the name of the dialect being extended —
///   confirm against the caller).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";
/// Template for the operation class. Declares a class deriving from
/// `_ods_ir.OpView` and registered against `_Dialect` through
/// `_ods_cext.register_operation`:
///   {0} is the Python class name;
///   {1} is the operation name (stored in `OPERATION_NAME`).
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):
OPERATION_NAME = "{1}"
)Py";
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
_ODS_{0}_SEGMENTS = {1}
)Py";
/// Template for class level declarations of the _ODS_REGIONS spec, emitted as
/// a two-element Python tuple:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
_ODS_REGIONS = ({0}, {1})
)Py";
/// Template for single-element accessor. Emits a Python property that indexes
/// directly into the operation's operand/result list:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
@builtins.property
def {0}(self):
return self.operation.{1}s[{2}]
)Py";
/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The emitted code first derives the length of the variable-length group as
/// `len({1}s) - {2} + 1` and then shifts the index {3} by that length minus
/// one. This works for both a single variadic group (non-negative length) and
/// a single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
@builtins.property
def {0}(self):
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";
/// Template for an optional element accessor; the emitted property returns
/// either the element or None:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
@builtins.property
def {0}(self):
return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";
/// Template for the variadic group accessor in the single variadic group case.
/// The emitted property returns a slice starting at {3} whose length is
/// `len({1}s) - {2} + 1`:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
@builtins.property
def {0}(self):
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
/// First part of the template for equally-sized variadic group accessor. It
/// opens the property and binds `start` and `elements_per_group` from
/// `_ods_equally_sized_accessor`; the template deliberately has no trailing
/// newline or return statement, which one of the "second part" templates
/// supplies:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of non-variadic groups;
///   {3} is the total number of variadic groups;
///   {4} is the number of non-variadic groups preceding the current group;
///   {5} is the number of variadic groups preceding the current group.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
@builtins.property
def {0}(self):
start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element at the `start` index computed by the prefix template:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
return self.operation.{0}s[start]
)Py";
/// Second part of the template for equally-sized case, accessing a variadic
/// group as the slice `[start, start + elements_per_group)` using the names
/// bound by the prefix template:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
return self.operation.{0}s[start:start + elements_per_group]
)Py";
/// Template for an attribute-sized group accessor. The emitted property calls
/// `_ods_segmented_accessor` with the element list, the
/// `"{1}SegmentSizes"` attribute and the group position, then returns the
/// resulting range with the suffix {3} applied:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
constexpr const char *opVariadicSegmentTemplate = R"Py(
@builtins.property
def {0}(self):
{1}_range = _ods_segmented_accessor(
self.operation.{1}s,
self.operation.attributes["{1}SegmentSizes"], {2})
return {1}_range{3}
)Py";
/// Template for a suffix when accessing an optional element in the
/// attribute-sized case. Appended after `{0}_range`, it yields the first
/// element when the range is non-empty and None otherwise:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
R"Py([0] if len({0}_range) > 0 else None)Py";
/// Template for an operation attribute getter. The emitted property looks the
/// attribute up unconditionally in `self.operation.attributes`:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeGetterTemplate = R"Py(
@builtins.property
def {0}(self):
return self.operation.attributes["{1}"]
)Py";
/// Template for an optional operation attribute getter. The emitted property
/// returns None when the attribute is not present on the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
@builtins.property
def {0}(self):
if "{1}" not in self.operation.attributes:
return None
return self.operation.attributes["{1}"]
)Py";
/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
@builtins.property
def {0}(self):
return "{1}" in self.operation.attributes
)Py";
/// Template for an operation attribute setter. The emitted setter raises a
/// Python `ValueError` when given None, since this template is for mandatory
/// attributes:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeSetterTemplate = R"Py(
@{0}.setter
def {0}(self, value):
if value is None:
raise ValueError("'None' not allowed as value for mandatory attributes")
self.operation.attributes["{1}"] = value
)Py";
/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (and is a no-op if the attribute is already absent):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
@{0}.setter
def {0}(self, value):
if value is not None:
self.operation.attributes["{1}"] = value
elif "{1}" in self.operation.attributes:
del self.operation.attributes["{1}"]
)Py";
/// Template for a setter of a unit operation attribute. Any truthy value
/// attaches `_ods_ir.UnitAttr.get()`; any falsy value (including None or
/// False) removes the attribute if it is present:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
@{0}.setter
def {0}(self, value):
if bool(value):
self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
elif "{1}" in self.operation.attributes:
del self.operation.attributes["{1}"]
)Py";
/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation. The emitted `del` is unconditional, i.e.
/// it does not first check that the attribute is present:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
@{0}.deleter
def {0}(self):
del self.operation.attributes["{1}"]
)Py";
/// Template for a region accessor property:
///   {0} is the name of the accessor;
///   {1} is the expression placed between the brackets of `self.regions[...]`
///   (an index or slice selecting the region(s)).
constexpr const char *regionAccessorTemplate = R"Py(
@builtins.property
def {0}(self):
return self.regions[{1}]
)Py";
/// Template for a free Python function that forwards to a callable and returns
/// the call expression with a suffix applied:
///   {0} is the name of the emitted function;
///   {1} is the callable being invoked;
///   {2} is the parameter list of the emitted function;
///   {3} is the argument list passed to {1};
///   {4} is the return type annotation;
///   {5} is text appended verbatim after the call expression.
/// NOTE(review): the call site is outside this part of the file; presumably
/// {1} is the generated op class and {5} selects its result(s) — confirm.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
return {1}({3}){5}
)Py";
/// Variant of the value builder template that wraps the call in
/// `_get_op_result_or_op_results` instead of appending a suffix:
///   {0} is the name of the emitted function;
///   {1} is the callable being invoked;
///   {2} is the parameter list of the emitted function;
///   {3} is the argument list passed to {1};
///   {4} is the return type annotation.
constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> {4}:
return _get_op_result_or_op_results({1}({3}))
)Py";
/// Command-line option category under which the options below are grouped.
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
/// `-bind-dialect`: the dialect to run the generator for. Defaults to the
/// empty string.
static llvm::cl::opt<std::string>
clDialectName("bind-dialect",
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
/// `-dialect-extension`: the prefix of the dialect extension. Defaults to the
/// empty string.
static llvm::cl::opt<std::string> clDialectExtensionName(
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
/// Map from string to string.
/// NOTE(review): no use of this alias is visible in this part of the file;
/// presumably it maps attribute names to Python class names — confirm.
using AttributeClasses = DenseMap<StringRef, StringRef>;
/// Returns true if `str` collides with a name owned by the generated code or
/// by the OpView API: anything starting with `_ods_` or ending in `_ods`, or
/// one of the fixed member names listed below.
static bool isODSReserved(StringRef str) {
  // The `_ods` prefix/suffix namespace is claimed wholesale by generated code.
  if (str.starts_with("_ods_") || str.ends_with("_ods"))
    return true;

  // Fixed names that would shadow a variable or attribute of the OpView API.
  static llvm::StringSet<> opViewNames(
      {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
       "loc", "verify", "regions", "results", "self", "operation",
       "DIALECT_NAMESPACE", "OPERATION_NAME"});
  return opViewNames.contains(str);
}
/// Modifies the `name` in a way that it becomes suitable for Python bindings
/// (does not change the `name` if it already is suitable) and returns the
/// modified version.
///
/// The rules are applied in this order:
///   1. every non-alphanumeric character is replaced with `_`;
///   2. a name starting with a digit gets a leading `_` and is returned;
///   3. a name that is reserved by Python or by the ODS-generated API gets a
///      trailing `_`.
/// An empty `name` is returned unchanged unless one of the reserved-name
/// checks flags it.
static std::string sanitizeName(StringRef name) {
  std::string processedStr = name.str();
  std::replace_if(
      processedStr.begin(), processedStr.end(),
      [](char c) { return !llvm::isAlnum(c); }, '_');

  // Check for emptiness first: dereferencing `begin()` of an empty string
  // reads through the end iterator, which is undefined behavior.
  if (!processedStr.empty() && llvm::isDigit(processedStr.front()))
    return "_" + processedStr;

  if (isPythonReserved(processedStr) || isODSReserved(processedStr))
    return processedStr + "_";
  return processedStr;
}
static std::string attrSizedTraitForKind(const char *kind) {
return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
StringRef(kind).take_front().upper(),
StringRef(kind).drop_front());
}
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
unsigned numVariadicGroups, unsigned numElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
kind) &&
"unsupported kind");
// Traits indicating how to process variadic elements.
std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
StringRef(kind).take_front().upper(),
StringRef(kind).drop_front());
std::string attrSizedTrait = attrSizedTraitForKind(kind);
// If there is only one variable-length element group, its size can be
// inferred from the total number of elements. If there are none, the
// generation is straightforward.
if (numVariadicGroups <= 1) {
bool seenVariableLength = false;
for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isVariableLength())
seenVariableLength = true;
if (element.name.empty())
continue;
if (element.isVariableLength()) {
os << formatv(element.isOptional() ? opOneOptionalTemplate
: opOneVariadicTemplate,
sanitizeName(element.name), kind, numElements, i);
} else if (seenVariableLength) {
os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
kind, numElements, i);
} else {
os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i);
}
}
return;
}
// Handle the operations where variadic groups have the same size.
if (op.getTrait(sameSizeTrait)) {
// Count the number of simple elements
unsigned numSimpleLength = 0;
for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (!element.isVariableLength()) {
++numSimpleLength;
}
}
// Generate the accessors
int numPrecedingSimple = 0;
int numPrecedingVariadic = 0;
for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (!element.name.empty()) {
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
kind, numSimpleLength, numVariadicGroups,
numPrecedingSimple, numPrecedingVariadic);
os << formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
kind);
}
if (element.isVariableLength())
++numPrecedingVariadic;
else
++numPrecedingSimple;
}
return;
}
// Handle the operations where the size of groups (variadic or not) is
// provided as an attribute. For non-variadic elements, make sure to return
// an element rather than a singleton container.
if (op.getTrait(attrSizedTrait)) {
for (unsigned i = 0; i < numElements; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.name.empty())
continue;
std::string trailing;
if (!element.isVariableLength())
trailing = "[0]";
else if (element.isOptional())
trailing = std::string(
formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
i, trailing);
}
return;
}
llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}
/// Free-function wrapper around `Operator::getNumOperands`.
static int getNumOperands(const Operator &op) {
  const int count = op.getNumOperands();
  return count;
}
/// Free-function wrapper around `Operator::getOperand`, usable where a plain
/// callable taking the operator and an index is expected.
static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
  const NamedTypeConstraint &operand = op.getOperand(i);
  return operand;
}
/// Free-function wrapper around `Operator::getNumResults`.
static int getNumResults(const Operator &op) {
  const int count = op.getNumResults();
  return count;
}
/// Free-function wrapper around `Operator::getResult`, usable where a plain
/// callable taking the operator and an index is expected.
static const NamedTypeConstraint &getResult(const Operator &op, int i) {
  const NamedTypeConstraint &result = op.getResult(i);
  return result;
}
/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
getNumOperands(op), getOperand);
}
/// Emits accessors Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(),
getNumResults(op), getResult);
}
/// Emits accessors to Op attributes.
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
for (const auto &namedAttr : op.getAttributes()) {
// Skip "derived" attributes because they are just C++ functions that we
// don't currently expose.
if (namedAttr.attr.isDerivedAttr())
continue;
if (namedAttr.name.empty())
continue;
std::string sanitizedName = sanitizeName(namedAttr.name);
// Unit attributes are handled specially.
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
continue;
}
if (namedAttr.attr.isOptional()) {
os << formatv(optionalAttributeGetterTemplate, sanitizedName,
namedAttr.name);
os << formatv(optionalAttributeSetterTemplate, sanitizedName,
namedAttr.name);
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
} else {
os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name);
os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name);
// Non-optional attributes cannot be deleted.
}
}
}
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields;
///   {2} is the argument list forwarded to `super().__init__`.
/// Before {1} runs, the emitted `__init__` initializes `operands` and
/// `results` to empty lists, `attributes` to an empty dict (`{{` is the
/// format-string escape for a literal `{`) and `regions` to None.
constexpr const char *initTemplate = R"Py(
def __init__(self, {0}):
operands = []
results = []
attributes = {{}
regions = None
{1}
super().__init__({2})
)Py";
/// Templates for appending a single element to the operand/result list.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Templates for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The plain operand/result variants skip the append when the value is None;
/// the attr-sized operand variant appends unconditionally (including None).
constexpr const char *optionalAppendOperandTemplate =
"if {0} is not None: operands.append({0})";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
"operands.append({0})";
constexpr const char *optionalAppendResultTemplate =
"if {0} is not None: results.append({0})";
/// Templates for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The operand variants pass the field through `_get_op_results_or_values`;
/// the "Pack" variant appends the resulting list as one entry, whereas the
/// other variants splice the individual elements in with `extend`.
constexpr const char *multiOperandAppendTemplate =
"operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
"operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
///   {0} is the builder argument name;
///   {1} is the attribute name, used as the key into `attributes`;
///   {2} is the key under which the attribute builder from raw is looked up in
///       `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute. The emitted code
/// reads `_ods_context`, which must be defined before this line runs.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder. Nothing is stored when the argument is None.
///   {0} is the builder argument name;
///   {1} is the attribute name, used as the key into `attributes`;
///   {2} is the key under which the attribute builder from raw is looked up in
///       `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute. The emitted code
/// reads `_ods_context`, which must be defined before this line runs.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template for initializing a unit attribute in the operation builder. The
/// attribute is only added when the argument is truthy.
///   {0} is the attribute name, used as the key into `attributes`;
///   {1} is the builder argument name.
/// Note that the placeholder order differs from the other attribute
/// initialization templates, where {0} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
_ods_get_default_loc_context(loc)))Py";
/// Template to initialize the successors list in the builder. The emitted line
/// binds the Python variable `_ods_successors`.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
/// Template to append or extend the list of successors in the builder. The
/// emitted line calls a method on the Python variable `_ods_successors`.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
/// Returns true if the `SameOperandsAndResultType` trait can be used to infer
/// result types of the given operation, i.e. the op carries the trait and has
/// no variable-length results.
static bool hasSameArgumentAndResultTypes(const Operator &op) {
  if (op.getNumVariableLengthResults() != 0)
    return false;
  return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") != nullptr;
}
/// Returns true if the `FirstAttrDerivedResultType` trait can be used to infer
/// result types of the given operation, i.e. the op carries the trait and has
/// no variable-length results.
static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
  if (op.getNumVariableLengthResults() != 0)
    return false;
  return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") != nullptr;
}
/// Returns true if `InferTypeOpInterface` can be used to infer result types of
/// the given operation, i.e. the op implements the interface and declares no
/// regions.
static bool hasInferTypeInterface(const Operator &op) {
  if (op.getNumRegions() != 0)
    return false;
  return op.getTrait("::mlir::InferTypeOpInterface::Trait") != nullptr;
}
/// Returns true if any of the supported inference mechanisms (same
/// operand/result type trait, first-attribute-derived trait, or the type
/// inference interface) applies to the given operation.
static bool canInferType(const Operator &op) {
  if (hasSameArgumentAndResultTypes(op))
    return true;
  if (hasFirstAttrDerivedResultTypes(op))
    return true;
  return hasInferTypeInterface(op);
}
/// Appends one sanitized argument name per op result to `builderArgs`, unless
/// the result types can be inferred, in which case the builder takes no result
/// arguments and nothing is appended. Unnamed results are called `result` when
/// the op has exactly one result and `_gen_res_<index>` otherwise.
static void
populateBuilderArgsResults(const Operator &op,
                           SmallVectorImpl<std::string> &builderArgs) {
  if (canInferType(op))
    return;

  const int numResults = op.getNumResults();
  for (int idx = 0; idx < numResults; ++idx) {
    std::string resultName = op.getResultName(idx).str();
    if (resultName.empty()) {
      // A lone unnamed result is called 'result' so that it matches the
      // built-in result accessor.
      if (numResults == 1)
        resultName = "result";
      else
        resultName = formatv("_gen_res_{0}", idx).str();
    }
    builderArgs.push_back(sanitizeName(resultName));
  }
}
/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments using intermixed attributes and operands in the same order as they
/// appear in the `arguments` field of the op definition. Additionally,
/// `operandNames` is populated with names of operands in their order of
/// appearance.
static void populateBuilderArgs(const Operator &op,
SmallVectorImpl<std::string> &builderArgs,
SmallVectorImpl<std::string> &operandNames) {
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str();
if (name.empty())
name = formatv("_gen_arg_{0}", i);
name = sanitizeName(name);
builderArgs.push_back(name);
if (!isa<NamedAttribute *>(op.getArg(i)))
operandNames.push_back(name);
}
}
/// Populates `builderArgs` with the Python-compatible names of builder function
/// successor arguments. Additionally, `successorArgNames` is also populated.
static void
populateBuilderArgsSuccessors(const Operator &op,
SmallVectorImpl<std::string> &builderArgs,
SmallVectorImpl<std::string> &successorArgNames) {
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
NamedSuccessor successor = op.getSuccessor(i);
std::string name = std::string(successor.name);
if (name.empty())
name = formatv("_gen_successor_{0}", i);
name = sanitizeName(name);
builderArgs.push_back(name);
successorArgNames.push_back(name);
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up operation attributes. `argNames` is expected to contain
/// the names of builder arguments that correspond to op arguments, i.e. to the
/// operands and attributes in the same order as they appear in the `arguments`
/// field.
static void
populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
SmallVectorImpl<std::string> &builderLines) {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
if (!attribute)
continue;
// Unit attributes are handled specially.
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
builderLines.push_back(
formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
continue;
}
builderLines.push_back(formatv(
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up successors. successorArgNames is expected to correspond
/// to the Python argument name for each successor on the op.
static void
populateBuilderLinesSuccessors(const Operator &op,
ArrayRef<std::string> successorArgNames,
SmallVectorImpl<std::string> &builderLines) {
if (successorArgNames.empty()) {
builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
return;
}
builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
auto &argName = successorArgNames[i];
const NamedSuccessor &successor = op.getSuccessor(i);
builderLines.push_back(formatv(addSuccessorTemplate,
successor.isVariadic() ? "extend" : "append",
argName));
}
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op operands.
static void
populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
// For each element, find or generate a name.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
const NamedTypeConstraint &element = op.getOperand(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleOperandAppendTemplate;
} else if (element.isOptional()) {
if (sizedSegments) {
formatString = optionalAppendAttrSizedOperandsTemplate;
} else {
formatString = optionalAppendOperandTemplate;
}
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list-typed
// element. Otherwise, we extend the actual operands.
if (sizedSegments) {
formatString = multiOperandAppendPackTemplate;
} else {
formatString = multiOperandAppendTemplate;
}
}
builderLines.push_back(formatv(formatString.data(), name));
}
}
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// The emitted code binds `_ods_derived_result_type` to the `.value` of the
/// attribute when it is a `TypeAttr`, and to its `.type` otherwise. The
/// template spans several lines and reads the `attributes` dict, so that dict
/// must already contain {0} when the emitted code runs.
constexpr const char *deriveTypeFromAttrTemplate =
R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
_ods_result_type_source_attr.type))Py";
/// Python code template appending {0} type {1} times to the results list.
///   {0} is a Python expression evaluating to the type;
///   {1} is the repetition count.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
/// Splits `string` on newline characters and appends each piece to
/// `builderLines` as its own entry. At least one entry is always appended (an
/// empty input yields a single empty line), and a single trailing newline does
/// not produce an extra empty entry.
static void appendLineByLine(StringRef string,
                             SmallVectorImpl<std::string> &builderLines) {
  StringRef remaining = string;
  while (true) {
    // `split` returns the text before the first newline and everything after
    // it; the latter is empty once no newline is left.
    const std::pair<StringRef, StringRef> parts = remaining.split('\n');
    builderLines.push_back(parts.first.str());
    remaining = parts.second;
    if (remaining.empty())
      break;
  }
}
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
static void
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
if (hasSameArgumentAndResultTypes(op)) {
builderLines.push_back(formatv(appendSameResultsTemplate,
"operands[0].type", op.getNumResults()));
return;
}
if (hasFirstAttrDerivedResultTypes(op)) {
const NamedAttribute &firstAttr = op.getAttribute(0);
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
"from which the type is derived");
appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(),
builderLines);
builderLines.push_back(formatv(appendSameResultsTemplate,
"_ods_derived_result_type",
op.getNumResults()));
return;
}
if (hasInferTypeInterface(op))
return;
// For each element, find or generate a name.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
const NamedTypeConstraint &element = op.getResult(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleResultAppendTemplate;
} else if (element.isOptional()) {
formatString = optionalAppendResultTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list-typed
// element. Otherwise, we extend the actual operands.
if (sizedSegments) {
formatString = singleResultAppendTemplate;
} else {
formatString = multiResultAppendTemplate;
}
}
builderLines.push_back(formatv(formatString.data(), name));
}
}
/// If the operation has variadic regions, adds a builder argument to specify
/// the number of those regions and builder lines to forward it to the generic
/// constructor.
static void populateBuilderRegions(const Operator &op,
SmallVectorImpl<std::string> &builderArgs,
SmallVectorImpl<std::string> &builderLines) {
if (op.hasNoVariadicRegions())
return;
// This is currently enforced when Operator is constructed.
assert(op.getNumVariadicRegions() == 1 &&
op.getRegion(op.getNumRegions() - 1).isVariadic() &&
"expected the last region to be varidic");
const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
std::string name =
("num_" + region.name.take_front().lower() + region.name.drop_front())
.str();
builderArgs.push_back(name);
builderLines.push_back(
formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
}
/// Emits a default builder constructing an operation from the list of its
/// result types, followed by a list of its operands.
///
/// Returns owned copies of the fully built Python function arguments, in
/// emission order: positional arguments, a bare `*`, keyword arguments (each
/// suffixed with `=None`), then `loc=None` and `ip=None`. Downstream users can
/// reuse this list instead of rebuilding it.
static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
                                                     raw_ostream &os) {
  SmallVector<std::string> builderArgs;
  SmallVector<std::string> builderLines;
  SmallVector<std::string> operandArgNames;
  SmallVector<std::string> successorArgNames;
  builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                      op.getNumNativeAttributes() + op.getNumSuccessors());

  // Collect argument names group by group, recording the size of each group
  // so that an index into builderArgs can later be mapped back to its kind.
  populateBuilderArgsResults(op, builderArgs);
  const size_t numResultArgs = builderArgs.size();
  populateBuilderArgs(op, builderArgs, operandArgNames);
  const size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
  populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);

  // Produce the body lines of the builder.
  populateBuilderLinesOperand(op, operandArgNames, builderLines);
  populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
                           builderLines);
  populateBuilderLinesResult(
      op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
  populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
  populateBuilderRegions(op, builderArgs, builderLines);

  // Layout of builderArgs vector elements:
  // [ result_args  operand_attr_args successor_args regions ]

  // Tells whether the builderArgs entry at the given index is emitted as a
  // python keyword argument.
  auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
    // Only operand/attribute arguments can be keyword arguments; result,
    // successor, and region arguments are always positional.
    const bool inOperandAttrRange =
        builderArgIndex >= numResultArgs &&
        builderArgIndex < (numResultArgs + numOperandAttrArgs);
    if (!inOperandAttrRange)
      return false;

    // Keyword arguments:
    // - optional named attributes (including unit attributes)
    // - default-valued named attributes
    // - optional operands
    Argument a = op.getArg(builderArgIndex - numResultArgs);
    if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
      return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
    if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
      return ntype->isOptional();
    return false;
  };

  // StringRefs in functionArgs refer to strings allocated by builderArgs.
  SmallVector<StringRef> functionArgs;

  // First pass: positional arguments, in their original order.
  const size_t numBuilderArgs = builderArgs.size();
  for (size_t idx = 0; idx < numBuilderArgs; ++idx) {
    if (isKeywordArgFn(idx))
      continue;
    functionArgs.push_back(builderArgs[idx]);
  }

  // A bare '*' forces every following argument to be passed by keyword.
  functionArgs.push_back("*");

  // Second pass: keyword arguments. The default value must be appended to the
  // string before a StringRef to it is taken.
  for (size_t idx = 0; idx < numBuilderArgs; ++idx) {
    if (!isKeywordArgFn(idx))
      continue;
    builderArgs[idx].append("=None");
    functionArgs.push_back(builderArgs[idx]);
  }
  functionArgs.push_back("loc=None");
  functionArgs.push_back("ip=None");

  // Arguments forwarded to the generic constructor call.
  SmallVector<std::string> initArgs;
  for (const char *leadingArg :
       {"self.OPERATION_NAME", "self._ODS_REGIONS",
        "self._ODS_OPERAND_SEGMENTS", "self._ODS_RESULT_SEGMENTS",
        "attributes=attributes"})
    initArgs.push_back(leadingArg);
  // Result types are passed explicitly only when they are not inferred.
  if (!hasInferTypeInterface(op))
    initArgs.push_back("results=results");
  for (const char *trailingArg :
       {"operands=operands", "successors=_ods_successors", "regions=regions",
        "loc=loc", "ip=ip"})
    initArgs.push_back(trailingArg);

  os << formatv(initTemplate, llvm::join(functionArgs, ", "),
                llvm::join(builderLines, "\n    "), llvm::join(initArgs, ", "));

  // Hand back owning copies, since functionArgs only views local storage.
  SmallVector<std::string> ownedFunctionArgs;
  ownedFunctionArgs.reserve(functionArgs.size());
  for (StringRef functionArg : functionArgs)
    ownedFunctionArgs.push_back(functionArg.str());
  return ownedFunctionArgs;
}
static void emitSegmentSpec(
const Operator &op, const char *kind,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement,
raw_ostream &os) {
std::string segmentSpec("[");
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isOptional()) {
segmentSpec.append("0,");
} else if (element.isVariadic()) {
segmentSpec.append("-1,");
} else {
segmentSpec.append("1,");
}
}
segmentSpec.append("]");
os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
}
/// Emits the `_ODS_REGIONS` class attribute for `op` as the pair
/// (min_region_count, has_no_variadic_regions), where the minimum count is the
/// number of regions that are not variadic.
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
  // Note that the base OpView class defines this as (0, True).
  const char *hasNoVariadicRegions =
      op.hasNoVariadicRegions() ? "True" : "False";
  unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
  os << formatv(opClassRegionSpecTemplate, minRegionCount,
                hasNoVariadicRegions);
}
/// Emits named accessors to regions. Unnamed regions get no accessor. A
/// variadic region is accessed with an open-ended slice (`<index>:`) rather
/// than a plain index.
static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
  for (const auto &indexedRegion : llvm::enumerate(op.getRegions())) {
    const NamedRegion &region = indexedRegion.value();
    if (region.name.empty())
      continue;

    assert((!region.isVariadic() ||
            indexedRegion.index() == op.getNumRegions() - 1) &&
           "expected only the last region to be variadic");

    // Index expression used by the accessor; slices from the index onwards
    // when the region is variadic.
    std::string subscript = std::to_string(indexedRegion.index());
    if (region.isVariadic())
      subscript += ":";

    os << formatv(regionAccessorTemplate, sanitizeName(region.name),
                  subscript);
  }
}
/// Emits builder that extracts results from op
static void emitValueBuilder(const Operator &op,
SmallVector<std::string> functionArgs,
raw_ostream &os) {
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
SmallVector<StringRef> argMaybeDefault =
llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
if (argMaybeDefault.size() == 2)
return arg + "=" + argMaybeDefault[1].str();
return arg;
});
// Actual args passed to op builder (e.g., opParam=op_param).
auto opBuilderArgs = llvm::map_range(
llvm::make_filter_range(functionArgs,
[](const std::string &s) { return s != "*"; }),
[](const std::string &arg) {
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
});
std::string nameWithoutDialect = sanitizeName(
op.getOperationName().substr(op.getOperationName().find('.') + 1));
if (nameWithoutDialect == op.getCppClassName())
nameWithoutDialect += "_";
std::string params = llvm::join(valueBuilderParams, ", ");
std::string args = llvm::join(opBuilderArgs, ", ");
const char *type =
(op.getNumResults() > 1
? "_Sequence[_ods_ir.Value]"
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"));
if (op.getNumVariableLengthResults() > 0) {
os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
op.getCppClassName(), params, args, type);
} else {
const char *results;
if (op.getNumResults() == 0) {
results = "";
} else if (op.getNumResults() == 1) {
results = ".result";
} else {
results = ".results";
}
os << formatv(valueBuilderTemplate, nameWithoutDialect,
op.getCppClassName(), params, args, type, results);
}
}
/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());
// Sized segments.
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
}
if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
}
emitRegionAttributes(op, os);
SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
emitAttributeAccessors(op, os);
emitResultAccessors(op, os);
emitRegionAccessors(op, os);
emitValueBuilder(op, functionArgs, os);
}
/// Emits bindings for the dialect specified in the command line, including file
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
os << fileHeader;
if (!clDialectExtensionName.empty())
os << formatv(dialectExtensionTemplate, clDialectName.getValue());
else
os << formatv(dialectClassTemplate, clDialectName.getValue());
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
if (op.getDialectName() == clDialectName.getValue())
emitOpBindings(op, os);
}
return false;
}
// Static registration object tying the `gen-python-op-bindings` generator name
// and its description to the `emitAllOps` entry point defined above.
static GenRegistration
genPythonBindings("gen-python-op-bindings",
"Generate Python bindings for MLIR Ops", &emitAllOps);