For Python versions < 3.12, pytype complains that: ``` error: in <module>: collections.abc.Buffer not supported yet [not-supported-yet] from collections.abc import Buffer as _Buffer ``` Since this code appears intended to support Python < 3.12, this disables the type error on that line.
1471 lines
58 KiB
C++
1471 lines
58 KiB
C++
//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
|
|
// binding classes wrapping a generic operation API.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "OpGenHelpers.h"
|
|
|
|
#include "mlir/Support/IndentedOstream.h"
|
|
#include "mlir/TableGen/GenInfo.h"
|
|
#include "mlir/TableGen/Operator.h"
|
|
#include "llvm/ADT/SmallVectorExtras.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/TableGen/Error.h"
|
|
#include "llvm/TableGen/Record.h"
|
|
#include <regex>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tblgen;
|
|
using llvm::formatv;
|
|
using llvm::Record;
|
|
using llvm::RecordKeeper;
|
|
|
|
/// File header and imports emitted at the top of every generated Python
/// module. The text contains no `{N}` format placeholders.
///
/// `collections.abc.Buffer` only exists on Python >= 3.12; older versions fall
/// back to `typing_extensions.Buffer` and, if that package is unavailable, to
/// `typing.Any`. The pytype suppression on the 3.12 import keeps pytype from
/// rejecting that line when checking against older Python versions.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)

import builtins
from typing import Any as _Any, Sequence as _Sequence, Union as _Union, Optional as _Optional
import sys as _sys
if _sys.version_info >= (3, 12):
  from collections.abc import Buffer as _Buffer # pytype: disable=not-supported-yet
else:
  try:
    from typing_extensions import Buffer as _Buffer
  except ImportError:
    _Buffer = _Any

)Py";
|
|
|
|
/// Template for the dialect class:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template importing `_Dialect` from another generated module instead of
/// defining it (presumably used when emitting a dialect extension):
///   {0} is the name used to form the `_{0}_ops_gen` module name.
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name;
///   {2} is the docstring for this operation.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):{2}
  OPERATION_NAME = "{1}"
)Py";

/// Template for operation adaptor class:
///   {0} is the Python class name of the adapted operation (the adaptor is
///       registered for it and named `{0}Adaptor`);
///   {1} is the operation name.
constexpr const char *opAdaptorClassTemplate = R"Py(
@_ods_cext.register_op_adaptor({0})
class {0}Adaptor(_ods_ir.OpAdaptor):
  OPERATION_NAME = "{1}"
)Py";
|
|
|
|
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///    1 = single element (expect non sequence operand/result)
///    0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";
|
|
|
|
/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
///   {3} is the type hint.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {3}:
    return self.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
///   {4} is the type hint.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// `_ods_variadic_group_length` is the size of the one variable-length group:
/// the element count minus the ({2} - 1) groups that hold exactly one element.
/// The current element is therefore shifted by that length minus one.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {4}:
    _ods_variadic_group_length = len(self.{1}s) - {2} + 1
    return self.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
///   {4} is the type hint.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self) -> _Optional[{4}]:
    return None if len(self.{1}s) < {2} else self.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
///   {4} is the type hint.
/// Returns the slice covering the whole variadic group, whose length is
/// inferred the same way as in opSingleAfterVariableTemplate.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {4}:
    _ods_variadic_group_length = len(self.{1}s) - {2} + 1
    return self.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
|
|
|
|
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of non-variadic groups;
///   {3} is the total number of variadic groups;
///   {4} is the number of non-variadic groups preceding the current group;
///   {5} is the number of variadic groups preceding the current group.
///   {6} is the type hint.
/// Intentionally has no trailing newline: one of the two "second part"
/// templates below is appended to complete the method body.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {6}:
    start, elements_per_group = _ods_equally_sized_accessor(self.{1}s, {2}, {3}, {4}, {5}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.{0}s[start:start + elements_per_group]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result' (used for the local variable name
///       and for the `{1}SegmentSizes` attribute key);
///   {2} is the position of the group in the group list;
///   {3} is a return suffix ("[0]" for single-element, empty for variadic,
///       and opVariadicSegmentOptionalTrailingTemplate for optional);
///   {4} is the type hint;
///   {5} is the element list attribute relative to `self`, without the
///       trailing "s" (e.g. "operand" or "operation.operand");
///   {6} is the attribute container relative to `self` (e.g. "attributes" or
///       "operation.attributes").
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {4}:
    {1}_range = _ods_segmented_accessor(
         self.{5}s,
         self.{6}["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
|
|
|
|
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {2}:
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter; returns None when the
/// attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> _Optional[{2}]:
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for an operation attribute getter for adaptors (reads
/// `self.attributes` instead of `self.operation.attributes`):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint.
constexpr const char *adaptorAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {2}:
    return self.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter for adaptors; returns
/// None when the attribute is absent:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint.
constexpr const char *adaptorOptionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> _Optional[{2}]:
    if "{1}" not in self.attributes:
      return None
    return self.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> bool:
    return "{1}" in self.operation.attributes
)Py";

/// Template for a getter of a unit operation attribute for adaptors, returns
/// True if the unit attribute is present, False otherwise (unit attributes have
/// meaning by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *adaptorUnitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> bool:
    return "{1}" in self.attributes
)Py";
|
|
|
/// Template for an operation attribute setter; rejects None because the
/// attribute is mandatory:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value: {2}):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute (if present):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value: _Optional[{2}]):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute (any truthy value sets it):
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Unlike the setters, this does not check for presence first.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";
|
|
|
|
/// Template for a region accessor:
///   {0} is the name of the accessor;
///   {1} is the index of the region;
///   {2} is the type hint.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {2}:
    return self.regions[{1}]
)Py";

/// Template for a module-level "value builder" function that constructs the
/// op and returns something derived from it:
///   {0} is the function name;
///   {1} is the callable building the op (presumably the op class);
///   {2} is the function's parameter list;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type hint;
///   {5} is a suffix appended to the construction expression (may be empty).
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return {1}({3}){5}
)Py";

/// Variant of valueBuilderTemplate for ops whose result count is not fixed:
/// returns the result list if there are several results, the single result if
/// there is exactly one, and the op itself if there are none. Placeholders
/// {0}-{3} have the same meaning as in valueBuilderTemplate.
constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
  op = {1}({3}); results = op.results
  return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
)Py";
|
|
|
|
/// Command-line option category grouping the options of this generator.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

// Backing storage for `-bind-dialect` (see `llvm::cl::location` below).
// NOTE(review): unlike the other options here, this and `clDialectName` are
// not `static`; presumably they are referenced from another translation
// unit — confirm before giving them internal linkage.
std::string dialectNameStorage;

/// `-bind-dialect`: the dialect to generate bindings for.
llvm::cl::opt<std::string, /*ExternalStorage=*/true>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::location(dialectNameStorage),
                  llvm::cl::cat(clOpPythonBindingCat));

/// `-dialect-extension`: the prefix of the dialect extension; empty by
/// default.
static llvm::cl::opt<std::string> clDialectExtensionName(
    "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
    llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// Alias for a StringRef -> StringRef map.
using AttributeClasses = DenseMap<StringRef, StringRef>;
|
|
|
|
/// Checks whether `str` would shadow a generated variable or attribute
|
|
/// part of the OpView API.
|
|
static bool isODSReserved(StringRef str) {
|
|
static llvm::StringSet<> reserved(
|
|
{"attributes", "create", "context", "ip", "operands", "print", "get_asm",
|
|
"loc", "verify", "regions", "results", "self", "operation",
|
|
"DIALECT_NAMESPACE", "OPERATION_NAME"});
|
|
return str.starts_with("_ods_") || str.ends_with("_ods") ||
|
|
reserved.contains(str);
|
|
}
|
|
|
|
/// Modifies the `name` in a way that it becomes suitable for Python bindings
|
|
/// (does not change the `name` if it already is suitable) and returns the
|
|
/// modified version.
|
|
static std::string sanitizeName(StringRef name) {
|
|
std::string processedStr = name.str();
|
|
std::replace_if(
|
|
processedStr.begin(), processedStr.end(),
|
|
[](char c) { return !llvm::isAlnum(c); }, '_');
|
|
|
|
if (llvm::isDigit(*processedStr.begin()))
|
|
return "_" + processedStr;
|
|
|
|
if (isPythonReserved(processedStr) || isODSReserved(processedStr))
|
|
return processedStr + "_";
|
|
return processedStr;
|
|
}
|
|
|
|
static std::string attrSizedTraitForKind(const char *kind) {
|
|
return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
|
|
StringRef(kind).take_front().upper(),
|
|
StringRef(kind).drop_front());
|
|
}
|
|
|
|
/// Maps a C++ builtin type class name to the corresponding `_ods_ir` Python
/// class, used to parameterize value type hints. Returns an empty StringRef
/// for types without a dedicated Python class.
static StringRef getPythonType(StringRef cppType) {
  struct TypeMapping {
    const char *cppName;
    const char *pyName;
  };
  static const TypeMapping kMappings[] = {
      {"::mlir::MemRefType", "_ods_ir.MemRefType"},
      {"::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType"},
      {"::mlir::RankedTensorType", "_ods_ir.RankedTensorType"},
      {"::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType"},
      {"::mlir::VectorType", "_ods_ir.VectorType"},
      {"::mlir::IntegerType", "_ods_ir.IntegerType"},
      {"::mlir::FloatType", "_ods_ir.FloatType"},
      {"::mlir::IndexType", "_ods_ir.IndexType"},
      {"::mlir::ComplexType", "_ods_ir.ComplexType"},
      {"::mlir::TupleType", "_ods_ir.TupleType"},
      {"::mlir::NoneType", "_ods_ir.NoneType"},
  };
  for (const TypeMapping &mapping : kMappings)
    if (cppType == mapping.cppName)
      return mapping.pyName;
  return StringRef();
}
|
|
|
|
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
///
/// \param op                Operator whose elements get accessors.
/// \param os                Stream receiving the generated Python code.
/// \param kind              "operand" or "result" (asserted below).
/// \param numVariadicGroups Number of variable-length element groups.
/// \param numElements       Total number of element groups.
/// \param getElement        Returns the i-th element group of `op`.
/// \param isAdaptor         When true, accessors read `self.<kind>s` and
///                          `self.attributes` instead of going through
///                          `self.operation`.
///
/// Three layouts are handled: at most one variadic group (sizes inferred from
/// the element count), the SameVariadic{Operand,Result}Size trait, and the
/// AttrSized{Operand,Result}Segments trait. Anything else is a fatal error.
static void emitElementAccessors(
    const Operator &op, raw_ostream &os, const char *kind,
    unsigned numVariadicGroups, unsigned numElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement,
    bool isAdaptor = false) {
  assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
                            kind) &&
         "unsupported kind");

  // Traits indicating how to process variadic elements.
  std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
                                      StringRef(kind).take_front().upper(),
                                      StringRef(kind).drop_front());
  std::string attrSizedTrait = attrSizedTraitForKind(kind);

  // Element list attribute relative to `self`, without the trailing "s" that
  // the templates append: "operand" on adaptors, "operation.operand" on
  // OpViews (likewise for results).
  std::string pyAttrName = isAdaptor ? kind : std::string("operation.") + kind;

  // If there is only one variable-length element group, its size can be
  // inferred from the total number of elements. If there are none, the
  // generation is straightforward.
  if (numVariadicGroups <= 1) {
    bool seenVariableLength = false;
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      // Tracked before the unnamed-element skip: an unnamed variadic group
      // still shifts the positions of the elements after it.
      if (element.isVariableLength())
        seenVariableLength = true;
      if (element.name.empty())
        continue;
      // Type hint for a single value, parameterized by the concrete Python
      // type class when one is known.
      std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                           : "_ods_ir.OpResult";
      if (StringRef pythonType = getPythonType(element.constraint.getCppType());
          !pythonType.empty())
        type = llvm::formatv("{0}[{1}]", type, pythonType);
      if (element.isVariableLength()) {
        if (element.isOptional()) {
          os << formatv(opOneOptionalTemplate, sanitizeName(element.name),
                        pyAttrName, numElements, i, type);
        } else {
          // A variadic group is exposed as a list, so use the list hint.
          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
                                                   : "_ods_ir.OpResultList";
          if (StringRef pythonType =
                  getPythonType(element.constraint.getCppType());
              !pythonType.empty())
            type = llvm::formatv("{0}[{1}]", type, pythonType);
          os << formatv(opOneVariadicTemplate, sanitizeName(element.name),
                        pyAttrName, numElements, i, type);
        }
      } else if (seenVariableLength) {
        // Position depends on the size of the preceding variable-length group.
        os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
                      pyAttrName, numElements, i, type);
      } else {
        os << formatv(opSingleTemplate, sanitizeName(element.name), pyAttrName,
                      i, type);
      }
    }
    return;
  }

  // Handle the operations where variadic groups have the same size.
  if (op.getTrait(sameSizeTrait)) {
    // Count the number of simple elements
    unsigned numSimpleLength = 0;
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.isVariableLength()) {
        ++numSimpleLength;
      }
    }

    // Generate the accessors
    int numPrecedingSimple = 0;
    int numPrecedingVariadic = 0;
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.name.empty()) {
        std::string type;
        if (element.isVariableLength()) {
          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
                                                   : "_ods_ir.OpResultList";
        } else {
          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                   : "_ods_ir.OpResult";
        }
        if (StringRef pythonType =
                getPythonType(element.constraint.getCppType());
            !pythonType.empty()) {
          type = llvm::formatv("{0}[{1}]", type, pythonType);
        }
        // The prefix computes `start`/`elements_per_group`; the suffix picks
        // either one element or the whole group.
        os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
                      pyAttrName, numSimpleLength, numVariadicGroups,
                      numPrecedingSimple, numPrecedingVariadic, type);
        os << formatv(element.isVariableLength()
                          ? opVariadicEqualVariadicTemplate
                          : opVariadicEqualSimpleTemplate,
                      pyAttrName);
      }
      // Counters advance for unnamed elements too, since they still occupy
      // positions.
      if (element.isVariableLength())
        ++numPrecedingVariadic;
      else
        ++numPrecedingSimple;
    }
    return;
  }

  // Handle the operations where the size of groups (variadic or not) is
  // provided as an attribute. For non-variadic elements, make sure to return
  // an element rather than a singleton container.
  if (op.getTrait(attrSizedTrait)) {
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.name.empty())
        continue;
      // Suffix applied to the segment slice: "[0]" for a single element, the
      // optional-trailing expression for an optional one, empty otherwise.
      std::string trailing;
      std::string type = std::strcmp(kind, "operand") == 0
                             ? "_ods_ir.OpOperandList"
                             : "_ods_ir.OpResultList";
      if (!element.isVariableLength() || element.isOptional()) {
        type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                 : "_ods_ir.OpResult";
        if (StringRef pythonType =
                getPythonType(element.constraint.getCppType());
            !pythonType.empty()) {
          type = llvm::formatv("{0}[{1}]", type, pythonType);
        }
        if (!element.isVariableLength()) {
          trailing = "[0]";
        } else if (element.isOptional()) {
          type = "_Optional[" + type + "]";
          trailing = std::string(
              formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
        }
      } else {
        if (StringRef pythonType =
                getPythonType(element.constraint.getCppType());
            !pythonType.empty()) {
          type = llvm::formatv("{0}[{1}]", type, pythonType);
        }
      }

      os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
                    i, trailing, type, pyAttrName,
                    isAdaptor ? "attributes" : "operation.attributes");
    }
    return;
  }

  llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}
|
|
|
|
/// Free function helpers accessing Operator components.
|
|
static int getNumOperands(const Operator &op) { return op.getNumOperands(); }
|
|
static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
|
|
return op.getOperand(i);
|
|
}
|
|
static int getNumResults(const Operator &op) { return op.getNumResults(); }
|
|
static const NamedTypeConstraint &getResult(const Operator &op, int i) {
|
|
return op.getResult(i);
|
|
}
|
|
|
|
/// Emits accessors to Op operands.
|
|
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
|
|
emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
|
|
getNumOperands(op), getOperand);
|
|
}
|
|
|
|
/// Emits accessors Op results.
|
|
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
|
|
emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(),
|
|
getNumResults(op), getResult);
|
|
}
|
|
|
|
/// Returns the name of the `_ods_ir` Python class used as the type hint for
/// `attr`, derived from its C++ storage type. Falls back to the generic
/// "Attribute" when the storage type has no dedicated Python class.
static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
  // Trimmed for consistency with the UnitAttr checks in the accessor
  // emitters, which compare against the trimmed storage type; otherwise a
  // storage type with stray whitespace silently degrades to "Attribute".
  StringRef storageTypeStr = attr.getStorageType().trim();

  // DenseElementsAttr is refined to DenseFPElementsAttr when the ODS
  // definition derives from one of the float-elements attribute classes.
  if (storageTypeStr == "::mlir::DenseElementsAttr") {
    llvm::StringSet<> superClasses;
    for (const Record *sc : attr.getDef().getSuperClasses())
      superClasses.insert(sc->getNameInitAsString());
    if (superClasses.contains("FloatElementsAttr") ||
        superClasses.contains("RankedFloatElementsAttr")) {
      return "DenseFPElementsAttr";
    }
    return "DenseElementsAttr";
  }

  // I1Attr is stored as an IntegerAttr but exposed as a BoolAttr.
  if (storageTypeStr == "::mlir::IntegerAttr")
    return attr.getAttrDefName() == "I1Attr" ? "BoolAttr" : "IntegerAttr";

  return llvm::StringSwitch<StringRef>(storageTypeStr)
      .Case("::mlir::AffineMapAttr", "AffineMapAttr")
      .Case("::mlir::ArrayAttr", "ArrayAttr")
      .Case("::mlir::BoolAttr", "BoolAttr")
      .Case("::mlir::DenseBoolArrayAttr", "DenseBoolArrayAttr")
      .Case("::mlir::DenseF32ArrayAttr", "DenseF32ArrayAttr")
      .Case("::mlir::DenseF64ArrayAttr", "DenseF64ArrayAttr")
      .Case("::mlir::DenseFPElementsAttr", "DenseFPElementsAttr")
      .Case("::mlir::DenseI16ArrayAttr", "DenseI16ArrayAttr")
      .Case("::mlir::DenseI32ArrayAttr", "DenseI32ArrayAttr")
      .Case("::mlir::DenseI64ArrayAttr", "DenseI64ArrayAttr")
      .Case("::mlir::DenseI8ArrayAttr", "DenseI8ArrayAttr")
      .Case("::mlir::DenseIntElementsAttr", "DenseIntElementsAttr")
      .Case("::mlir::DenseResourceElementsAttr", "DenseResourceElementsAttr")
      .Case("::mlir::DictionaryAttr", "DictAttr")
      .Case("::mlir::FlatSymbolRefAttr", "FlatSymbolRefAttr")
      .Case("::mlir::FloatAttr", "FloatAttr")
      .Case("::mlir::IntegerSetAttr", "IntegerSetAttr")
      .Case("::mlir::OpaqueAttr", "OpaqueAttr")
      .Case("::mlir::StridedLayoutAttr", "StridedLayoutAttr")
      .Case("::mlir::StringAttr", "StringAttr")
      .Case("::mlir::SymbolRefAttr", "SymbolRefAttr")
      .Case("::mlir::TypeAttr", "TypeAttr")
      .Case("::mlir::UnitAttr", "UnitAttr")
      .Default("Attribute")
      .str();
}
|
|
|
|
/// Returns the Python raw value type accepted by the AttrBuilder for the given
|
|
/// attribute. Returns empty StringRef if no mapping is known.
|
|
static StringRef getPythonAttrRawType(mlir::tblgen::Attribute attr) {
|
|
return llvm::StringSwitch<StringRef>(attr.getAttrDefName())
|
|
.Cases({"BoolAttr", "I1Attr"}, "bool")
|
|
.Cases({"I8Attr", "I16Attr", "I32Attr", "I64Attr"}, "int")
|
|
.Cases({"SI1Attr", "SI8Attr", "SI16Attr", "SI32Attr", "SI64Attr"}, "int")
|
|
.Cases({"UI1Attr", "UI8Attr", "UI16Attr", "UI32Attr", "UI64Attr"}, "int")
|
|
.Case("IndexAttr", "int")
|
|
.Cases({"F32Attr", "F64Attr"}, "float")
|
|
.Cases({"StrAttr", "SymbolNameAttr"}, "str")
|
|
.Cases({"FlatSymbolRefAttr", "SymbolRefAttr"}, "str")
|
|
.Case("TypeAttr", "_ods_ir.Type")
|
|
.Case("AffineMapAttr", "_ods_ir.AffineMap")
|
|
.Case("IntegerSetAttr", "_ods_ir.IntegerSet")
|
|
.Case("DictionaryAttr", "dict")
|
|
.Case("ArrayAttr", "_Sequence[_ods_ir.Attribute]")
|
|
.Cases({"I32ArrayAttr", "I64ArrayAttr", "I64SmallVectorArrayAttr"},
|
|
"_Sequence[int]")
|
|
.Cases({"F32ArrayAttr", "F64ArrayAttr"}, "_Sequence[float]")
|
|
.Cases({"BoolArrayAttr", "DenseBoolArrayAttr"}, "_Sequence[bool]")
|
|
.Cases({"StrArrayAttr", "FlatSymbolRefArrayAttr"}, "_Sequence[str]")
|
|
.Cases({"DenseI8ArrayAttr", "DenseI16ArrayAttr", "DenseI32ArrayAttr",
|
|
"DenseI64ArrayAttr"},
|
|
"_Sequence[int]")
|
|
.Cases({"DenseF32ArrayAttr", "DenseF64ArrayAttr"}, "_Sequence[float]")
|
|
.Cases({"I32ElementsAttr", "I64ElementsAttr", "IndexElementsAttr"},
|
|
"_Union[_Sequence[int], _Buffer]")
|
|
.Case("F64ElementsAttr", "_Union[_Sequence[float], _Buffer]")
|
|
.Default(StringRef());
|
|
}
|
|
|
|
/// Emits accessors to Op attributes.
|
|
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
|
|
for (const auto &namedAttr : op.getAttributes()) {
|
|
// Skip "derived" attributes because they are just C++ functions that we
|
|
// don't currently expose.
|
|
if (namedAttr.attr.isDerivedAttr())
|
|
continue;
|
|
|
|
if (namedAttr.name.empty())
|
|
continue;
|
|
|
|
std::string sanitizedName = sanitizeName(namedAttr.name);
|
|
|
|
// Unit attributes are handled specially.
|
|
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
|
os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
|
|
os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
|
|
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
|
|
continue;
|
|
}
|
|
|
|
std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr);
|
|
if (namedAttr.attr.isOptional()) {
|
|
os << formatv(optionalAttributeGetterTemplate, sanitizedName,
|
|
namedAttr.name, type);
|
|
os << formatv(optionalAttributeSetterTemplate, sanitizedName,
|
|
namedAttr.name, type);
|
|
os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
|
|
} else {
|
|
os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name,
|
|
type);
|
|
os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name,
|
|
type);
|
|
// Non-optional attributes cannot be deleted.
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Emits accessors to Op attributes for adaptors.
|
|
static void emitAdaptorAttributeAccessors(const Operator &op, raw_ostream &os) {
|
|
for (const auto &namedAttr : op.getAttributes()) {
|
|
// Skip "derived" attributes because they are just C++ functions that we
|
|
// don't currently expose.
|
|
if (namedAttr.attr.isDerivedAttr())
|
|
continue;
|
|
|
|
if (namedAttr.name.empty())
|
|
continue;
|
|
|
|
std::string sanitizedName = sanitizeName(namedAttr.name);
|
|
|
|
// Unit attributes are handled specially.
|
|
if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
|
os << formatv(adaptorUnitAttributeGetterTemplate, sanitizedName,
|
|
namedAttr.name);
|
|
continue;
|
|
}
|
|
|
|
std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr);
|
|
os << formatv(namedAttr.attr.isOptional()
|
|
? adaptorOptionalAttributeGetterTemplate
|
|
: adaptorAttributeGetterTemplate,
|
|
sanitizedName, namedAttr.name, type);
|
|
}
|
|
}
|
|
|
|
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields;
///   {2} is the argument list forwarded to `super().__init__`.
/// `{{` is the formatv escape for a literal `{`, so the generated line reads
/// `attributes = {}`.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    attributes = {{}
    regions = None
    {1}
    super().__init__({2})
)Py";
|
|
|
|
/// Templates for appending a single element to the operand/result list.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Templates for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The operand and result variants skip the element when it is None.
/// NOTE(review): the attr-sized operand variant appends unconditionally,
/// presumably so an absent optional operand still occupies its segment slot;
/// confirm against the builder code that selects it.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append({0})";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append({0})";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Templates for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "Pack" variant appends the converted sequence as one element rather
/// than extending the operand list with its items.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
|
|
|
|
/// Template for attribute builder from raw input in the operation builder.
///   {0} is the builder argument name;
///   {1} is the attribute name (the key in the `attributes` dict);
///   {2} is the name looked up in `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder; nothing is set when the argument is None.
///   {0} is the builder argument name;
///   {1} is the attribute name (the key in the `attributes` dict);
///   {2} is the name looked up in `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template setting a unit attribute in the operation builder when the builder
/// argument is truthy:
///   {0} is the attribute name (the key in the `attributes` dict);
///   {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";
|
|
|
|
/// Template to initialize the successors list in the builder. It is always
|
|
/// emitted, so the generic constructor can receive `_ods_successors`.
|
|
/// {0} is the initial value: `[]` when the op has successors, `None` otherwise.
|
|
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";
|
|
|
|
/// Template to append or extend the list of successors in the builder.
|
|
/// {0} is the list method ('append' for a single successor or 'extend' for a
/// variadic one);
|
|
/// {1} is the builder argument name holding the value to add.
|
|
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
|
|
|
|
/// Returns true if the SameArgumentAndResultTypes trait can be used to infer
|
|
/// result types of the given operation.
|
|
static bool hasSameArgumentAndResultTypes(const Operator &op) {
|
|
return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
|
|
op.getNumVariableLengthResults() == 0;
|
|
}
|
|
|
|
/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
|
|
/// result types of the given operation.
|
|
static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
|
|
return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
|
|
op.getNumVariableLengthResults() == 0;
|
|
}
|
|
|
|
/// Returns true if the InferTypeOpInterface can be used to infer result types
|
|
/// of the given operation.
|
|
static bool hasInferTypeInterface(const Operator &op) {
|
|
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
|
|
op.getNumRegions() == 0;
|
|
}
|
|
|
|
/// Returns true if there is a trait or interface that can be used to infer
|
|
/// result types of the given operation.
|
|
static bool canInferType(const Operator &op) {
|
|
return hasSameArgumentAndResultTypes(op) ||
|
|
hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
|
|
}
|
|
|
|
/// Populates `builderArgs` with result names if the builder is expected to
|
|
/// accept them as arguments.
|
|
static void
|
|
populateBuilderArgsResults(const Operator &op,
|
|
SmallVectorImpl<std::string> &builderArgs) {
|
|
if (canInferType(op))
|
|
return;
|
|
|
|
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
|
std::string name = op.getResultName(i).str();
|
|
if (name.empty()) {
|
|
if (op.getNumResults() == 1) {
|
|
// Special case for one result, make the default name be 'result'
|
|
// to properly match the built-in result accessor.
|
|
name = "result";
|
|
} else {
|
|
name = formatv("_gen_res_{0}", i);
|
|
}
|
|
}
|
|
name = sanitizeName(name);
|
|
builderArgs.push_back(name);
|
|
}
|
|
}
|
|
|
|
/// Populates `builderArgs` with the Python-compatible names of builder function
|
|
/// arguments using intermixed attributes and operands in the same order as they
|
|
/// appear in the `arguments` field of the op definition. Additionally,
|
|
/// `operandNames` is populated with names of operands in their order of
|
|
/// appearance.
|
|
static void populateBuilderArgs(const Operator &op,
|
|
SmallVectorImpl<std::string> &builderArgs,
|
|
SmallVectorImpl<std::string> &operandNames) {
|
|
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
|
std::string name = op.getArgName(i).str();
|
|
if (name.empty())
|
|
name = formatv("_gen_arg_{0}", i);
|
|
name = sanitizeName(name);
|
|
builderArgs.push_back(name);
|
|
if (!isa<NamedAttribute *>(op.getArg(i)))
|
|
operandNames.push_back(name);
|
|
}
|
|
}
|
|
|
|
/// Populates `builderArgs` with the Python-compatible names of builder function
|
|
/// successor arguments. Additionally, `successorArgNames` is also populated.
|
|
static void
|
|
populateBuilderArgsSuccessors(const Operator &op,
|
|
SmallVectorImpl<std::string> &builderArgs,
|
|
SmallVectorImpl<std::string> &successorArgNames) {
|
|
|
|
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
|
|
NamedSuccessor successor = op.getSuccessor(i);
|
|
std::string name = std::string(successor.name);
|
|
if (name.empty())
|
|
name = formatv("_gen_successor_{0}", i);
|
|
name = sanitizeName(name);
|
|
builderArgs.push_back(name);
|
|
successorArgNames.push_back(name);
|
|
}
|
|
}
|
|
|
|
/// Populates `builderLines` with additional lines that are required in the
|
|
/// builder to set up operation attributes. `argNames` is expected to contain
|
|
/// the names of builder arguments that correspond to op arguments, i.e. to the
|
|
/// operands and attributes in the same order as they appear in the `arguments`
|
|
/// field.
|
|
static void
|
|
populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
|
|
SmallVectorImpl<std::string> &builderLines) {
|
|
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
|
|
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
|
Argument arg = op.getArg(i);
|
|
auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
|
|
if (!attribute)
|
|
continue;
|
|
|
|
// Unit attributes are handled specially.
|
|
if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
|
builderLines.push_back(
|
|
formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
|
|
continue;
|
|
}
|
|
|
|
// For EnumAttr-style attributes (those defined as EnumAttr<Dialect, ...>
|
|
// in tablegen), use a dialect-qualified key ("dialect.AttrName") so the
|
|
// lookup matches the registration emitted by EnumPythonBindingGen with
|
|
// -bind-dialect. For all other attributes (plain attrs like I32Attr,
|
|
// custom AttrDef, etc.), keep the unqualified name to match their
|
|
// registrations in ir.py or dialect-specific Python files.
|
|
Attribute baseAttr = attribute->attr.getBaseAttr();
|
|
Dialect attrDialect = baseAttr.isSubClassOf("EnumAttr")
|
|
? baseAttr.getDialect()
|
|
: Dialect(nullptr);
|
|
std::string attrBuilderKey = attrDialect
|
|
? formatv("{0}.{1}", attrDialect.getName(),
|
|
attribute->attr.getAttrDefName())
|
|
.str()
|
|
: attribute->attr.getAttrDefName().str();
|
|
|
|
builderLines.push_back(formatv(
|
|
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
|
|
? initOptionalAttributeWithBuilderTemplate
|
|
: initAttributeWithBuilderTemplate,
|
|
argNames[i], attribute->name, attrBuilderKey));
|
|
}
|
|
}
|
|
|
|
/// Populates `builderLines` with additional lines that are required in the
|
|
/// builder to set up successors. successorArgNames is expected to correspond
|
|
/// to the Python argument name for each successor on the op.
|
|
static void
|
|
populateBuilderLinesSuccessors(const Operator &op,
|
|
ArrayRef<std::string> successorArgNames,
|
|
SmallVectorImpl<std::string> &builderLines) {
|
|
if (successorArgNames.empty()) {
|
|
builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
|
|
return;
|
|
}
|
|
|
|
builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
|
|
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
|
|
auto &argName = successorArgNames[i];
|
|
const NamedSuccessor &successor = op.getSuccessor(i);
|
|
builderLines.push_back(formatv(addSuccessorTemplate,
|
|
successor.isVariadic() ? "extend" : "append",
|
|
argName));
|
|
}
|
|
}
|
|
|
|
/// Populates `builderLines` with additional lines that are required in the
|
|
/// builder to set up op operands.
|
|
static void
|
|
populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
|
|
SmallVectorImpl<std::string> &builderLines) {
|
|
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
|
|
|
|
// For each element, find or generate a name.
|
|
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
|
const NamedTypeConstraint &element = op.getOperand(i);
|
|
std::string name = names[i];
|
|
|
|
// Choose the formatting string based on the element kind.
|
|
StringRef formatString;
|
|
if (!element.isVariableLength()) {
|
|
formatString = singleOperandAppendTemplate;
|
|
} else if (element.isOptional()) {
|
|
if (sizedSegments) {
|
|
formatString = optionalAppendAttrSizedOperandsTemplate;
|
|
} else {
|
|
formatString = optionalAppendOperandTemplate;
|
|
}
|
|
} else {
|
|
assert(element.isVariadic() && "unhandled element group type");
|
|
// If emitting with sizedSegments, then we add the actual list-typed
|
|
// element. Otherwise, we extend the actual operands.
|
|
if (sizedSegments) {
|
|
formatString = multiOperandAppendPackTemplate;
|
|
} else {
|
|
formatString = multiOperandAppendTemplate;
|
|
}
|
|
}
|
|
|
|
builderLines.push_back(formatv(formatString.data(), name));
|
|
}
|
|
}
|
|
|
|
/// Python code template of generating result types for
|
|
/// FirstAttrDerivedResultType trait
/// (runs only when the caller passed `results=None`; the type is the value of
/// a TypeAttr, otherwise the `.type` of the source attribute, and it is
/// replicated once per result).
|
|
/// - {0} is the name of the attribute from which to derive the types.
|
|
/// - {1} is the number of results.
|
|
constexpr const char *firstAttrDerivedResultTypeTemplate =
|
|
R"Py(if results is None:
|
|
_ods_result_type_source_attr = attributes["{0}"]
|
|
_ods_derived_result_type = (
|
|
_ods_ir.TypeAttr(_ods_result_type_source_attr).value
|
|
if isinstance(_ods_result_type_source_attr, _ods_ir.TypeAttr) else
|
|
_ods_result_type_source_attr.type)
|
|
results = [_ods_derived_result_type] * {1})Py";
|
|
|
|
/// Python code template of generating result types for
|
|
/// SameOperandsAndResultType trait
/// (runs only when the caller passed `results=None`; every result takes the
/// type of the first entry of `operands`).
|
|
/// - {0} is the number of results.
|
|
constexpr const char *sameOperandsAndResultTypeTemplate =
|
|
R"Py(if results is None: results = [operands[0].type] * {0})Py";
|
|
|
|
/// Appends the given multiline string as individual strings into
|
|
/// `builderLines`.
|
|
static void appendLineByLine(StringRef string,
|
|
SmallVectorImpl<std::string> &builderLines) {
|
|
|
|
std::pair<StringRef, StringRef> split = std::make_pair(string, string);
|
|
do {
|
|
split = split.second.split('\n');
|
|
builderLines.push_back(split.first.str());
|
|
} while (!split.second.empty());
|
|
}
|
|
|
|
/// Populates `builderLines` with additional lines that are required in the
|
|
/// builder to set up op results.
|
|
static void
|
|
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
|
|
SmallVectorImpl<std::string> &builderLines) {
|
|
if (hasSameArgumentAndResultTypes(op)) {
|
|
appendLineByLine(
|
|
formatv(sameOperandsAndResultTypeTemplate, op.getNumResults()).str(),
|
|
builderLines);
|
|
return;
|
|
}
|
|
|
|
if (hasFirstAttrDerivedResultTypes(op)) {
|
|
const NamedAttribute &firstAttr = op.getAttribute(0);
|
|
assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
|
|
"from which the type is derived");
|
|
appendLineByLine(formatv(firstAttrDerivedResultTypeTemplate, firstAttr.name,
|
|
op.getNumResults())
|
|
.str(),
|
|
builderLines);
|
|
return;
|
|
}
|
|
|
|
if (hasInferTypeInterface(op))
|
|
return;
|
|
|
|
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
|
|
builderLines.push_back("results = []");
|
|
|
|
// For each element, find or generate a name.
|
|
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
|
const NamedTypeConstraint &element = op.getResult(i);
|
|
std::string name = names[i];
|
|
|
|
// Choose the formatting string based on the element kind.
|
|
StringRef formatString;
|
|
if (!element.isVariableLength()) {
|
|
formatString = singleResultAppendTemplate;
|
|
} else if (element.isOptional()) {
|
|
formatString = optionalAppendResultTemplate;
|
|
} else {
|
|
assert(element.isVariadic() && "unhandled element group type");
|
|
// If emitting with sizedSegments, then we add the actual list-typed
|
|
// element. Otherwise, we extend the actual operands.
|
|
if (sizedSegments) {
|
|
formatString = singleResultAppendTemplate;
|
|
} else {
|
|
formatString = multiResultAppendTemplate;
|
|
}
|
|
}
|
|
|
|
builderLines.push_back(formatv(formatString.data(), name));
|
|
}
|
|
}
|
|
|
|
/// If the operation has variadic regions, adds a builder argument to specify
|
|
/// the number of those regions and builder lines to forward it to the generic
|
|
/// constructor.
|
|
static void populateBuilderRegions(const Operator &op,
|
|
SmallVectorImpl<std::string> &builderArgs,
|
|
SmallVectorImpl<std::string> &builderLines) {
|
|
if (op.hasNoVariadicRegions())
|
|
return;
|
|
|
|
// This is currently enforced when Operator is constructed.
|
|
assert(op.getNumVariadicRegions() == 1 &&
|
|
op.getRegion(op.getNumRegions() - 1).isVariadic() &&
|
|
"expected the last region to be varidic");
|
|
|
|
const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1);
|
|
std::string name =
|
|
("num_" + region.name.take_front().lower() + region.name.drop_front())
|
|
.str();
|
|
builderArgs.push_back(name);
|
|
builderLines.push_back(
|
|
formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
|
|
}
|
|
|
|
/// Emits a default builder constructing an operation from the list of its
|
|
/// result types, followed by a list of its operands. Returns vector
|
|
/// of fully built functionArgs for downstream users (to save having to
|
|
/// rebuild anew).
|
|
static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
|
|
raw_ostream &os) {
|
|
SmallVector<std::string> builderArgs;
|
|
SmallVector<std::string> builderLines;
|
|
SmallVector<std::string> operandArgNames;
|
|
SmallVector<std::string> successorArgNames;
|
|
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
|
|
op.getNumNativeAttributes() + op.getNumSuccessors());
|
|
populateBuilderArgsResults(op, builderArgs);
|
|
size_t numResultArgs = builderArgs.size();
|
|
populateBuilderArgs(op, builderArgs, operandArgNames);
|
|
size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
|
|
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
|
|
size_t numSuccessorArgs = successorArgNames.size();
|
|
|
|
populateBuilderLinesOperand(op, operandArgNames, builderLines);
|
|
populateBuilderLinesAttr(op, ArrayRef(builderArgs).drop_front(numResultArgs),
|
|
builderLines);
|
|
populateBuilderLinesResult(
|
|
op, ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
|
|
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
|
|
populateBuilderRegions(op, builderArgs, builderLines);
|
|
|
|
// Compute type annotations for each builder arg.
|
|
SmallVector<std::string> argTypes(builderArgs.size());
|
|
|
|
// Result args: user passes Type objects.
|
|
for (size_t i = 0; i < numResultArgs; ++i) {
|
|
const NamedTypeConstraint &result = op.getResult(i);
|
|
if (result.isVariadic())
|
|
argTypes[i] = "_Sequence[_ods_ir.Type]";
|
|
else if (result.isOptional())
|
|
argTypes[i] = "_Optional[_ods_ir.Type]";
|
|
else
|
|
argTypes[i] = "_ods_ir.Type";
|
|
}
|
|
|
|
// Operand and attribute args.
|
|
for (size_t i = 0; i < numOperandAttrArgs; ++i) {
|
|
size_t idx = numResultArgs + i;
|
|
Argument arg = op.getArg(i);
|
|
if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(arg)) {
|
|
if (nattr->attr.getStorageType().trim() == "::mlir::UnitAttr") {
|
|
argTypes[idx] = "bool";
|
|
} else {
|
|
std::string attrType = "_ods_ir." + getPythonAttrName(nattr->attr);
|
|
StringRef rawType = getPythonAttrRawType(nattr->attr);
|
|
argTypes[idx] =
|
|
llvm::formatv("_Union[{0}, {1}]",
|
|
rawType.empty() ? "_Any" : rawType, attrType)
|
|
.str();
|
|
}
|
|
} else if (auto *ntype =
|
|
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg)) {
|
|
std::string type = "_ods_ir.Value";
|
|
if (StringRef pythonType = getPythonType(ntype->constraint.getCppType());
|
|
!pythonType.empty()) {
|
|
type = llvm::formatv("{0}[{1}]", type, pythonType);
|
|
}
|
|
if (ntype->isVariadic())
|
|
type = llvm::formatv("_Sequence[{0}]", type);
|
|
argTypes[idx] = type;
|
|
}
|
|
// NamedProperty args are skipped (no type hint).
|
|
}
|
|
|
|
// Successor args.
|
|
for (size_t i = 0; i < numSuccessorArgs; ++i) {
|
|
size_t idx = numResultArgs + numOperandAttrArgs + i;
|
|
const NamedSuccessor &successor = op.getSuccessor(i);
|
|
argTypes[idx] =
|
|
successor.isVariadic() ? "_Sequence[_ods_ir.Block]" : "_ods_ir.Block";
|
|
}
|
|
|
|
// Region args (variadic region count).
|
|
for (size_t i = numResultArgs + numOperandAttrArgs + numSuccessorArgs;
|
|
i < builderArgs.size(); ++i) {
|
|
argTypes[i] = "int";
|
|
}
|
|
|
|
// Determine whether the argument corresponding to a given index into the
|
|
// builderArgs vector is a python keyword argument or not.
|
|
auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
|
|
// All result, successor, and region arguments are positional arguments.
|
|
if (builderArgIndex < numResultArgs ||
|
|
builderArgIndex >= numResultArgs + numOperandAttrArgs)
|
|
return false;
|
|
// Keyword arguments:
|
|
// - optional named attributes (including unit attributes)
|
|
// - default-valued named attributes
|
|
// - optional operands
|
|
Argument a = op.getArg(builderArgIndex - numResultArgs);
|
|
if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
|
|
return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
|
|
if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
|
|
return ntype->isOptional();
|
|
return false;
|
|
};
|
|
|
|
// Format a single function argument with optional type hint and default.
|
|
auto formatArg = [](StringRef name, StringRef typeHint,
|
|
bool isKeyword) -> std::string {
|
|
std::string result = name.str();
|
|
if (isKeyword && !typeHint.empty())
|
|
result += ": _Optional[" + typeHint.str() + "] = None";
|
|
else if (isKeyword)
|
|
result += "=None";
|
|
else if (!typeHint.empty())
|
|
result += ": " + typeHint.str();
|
|
return result;
|
|
};
|
|
|
|
// Build the function argument list: positional args, *, keyword args.
|
|
SmallVector<std::string> functionArgs;
|
|
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i)
|
|
if (!isKeywordArgFn(i))
|
|
functionArgs.push_back(formatArg(builderArgs[i], argTypes[i], false));
|
|
|
|
// Add a bare '*' to indicate that all following arguments must be keyword
|
|
// arguments.
|
|
functionArgs.push_back("*");
|
|
|
|
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i)
|
|
if (isKeywordArgFn(i))
|
|
functionArgs.push_back(formatArg(builderArgs[i], argTypes[i], true));
|
|
|
|
if (canInferType(op))
|
|
functionArgs.push_back(
|
|
"results: _Optional[_Sequence[_ods_ir.Type]] = None");
|
|
functionArgs.push_back("loc: _Optional[_ods_ir.Location] = None");
|
|
functionArgs.push_back("ip: _Optional[_ods_ir.InsertionPoint] = None");
|
|
|
|
SmallVector<std::string> initArgs;
|
|
initArgs.push_back("self.OPERATION_NAME");
|
|
initArgs.push_back("self._ODS_REGIONS");
|
|
initArgs.push_back("self._ODS_OPERAND_SEGMENTS");
|
|
initArgs.push_back("self._ODS_RESULT_SEGMENTS");
|
|
initArgs.push_back("attributes=attributes");
|
|
initArgs.push_back("results=results");
|
|
initArgs.push_back("operands=operands");
|
|
initArgs.push_back("successors=_ods_successors");
|
|
initArgs.push_back("regions=regions");
|
|
initArgs.push_back("loc=loc");
|
|
initArgs.push_back("ip=ip");
|
|
|
|
os << formatv(initTemplate, llvm::join(functionArgs, ", "),
|
|
llvm::join(builderLines, "\n "), llvm::join(initArgs, ", "));
|
|
return functionArgs;
|
|
}
|
|
|
|
static void emitSegmentSpec(
|
|
const Operator &op, const char *kind,
|
|
llvm::function_ref<int(const Operator &)> getNumElements,
|
|
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
|
|
getElement,
|
|
raw_ostream &os) {
|
|
std::string segmentSpec("[");
|
|
for (int i = 0, e = getNumElements(op); i < e; ++i) {
|
|
const NamedTypeConstraint &element = getElement(op, i);
|
|
if (element.isOptional()) {
|
|
segmentSpec.append("0,");
|
|
} else if (element.isVariadic()) {
|
|
segmentSpec.append("-1,");
|
|
} else {
|
|
segmentSpec.append("1,");
|
|
}
|
|
}
|
|
segmentSpec.append("]");
|
|
|
|
os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
|
|
}
|
|
|
|
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
|
|
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
|
|
// Note that the base OpView class defines this as (0, True).
|
|
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
|
|
os << formatv(opClassRegionSpecTemplate, minRegionCount,
|
|
op.hasNoVariadicRegions() ? "True" : "False");
|
|
}
|
|
|
|
/// Emits named accessors to regions.
|
|
static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
|
|
for (const auto &en : llvm::enumerate(op.getRegions())) {
|
|
const NamedRegion ®ion = en.value();
|
|
if (region.name.empty())
|
|
continue;
|
|
|
|
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
|
|
"expected only the last region to be variadic");
|
|
os << formatv(regionAccessorTemplate, sanitizeName(region.name),
|
|
std::to_string(en.index()) + (region.isVariadic() ? ":" : ""),
|
|
region.isVariadic() ? "_ods_ir.RegionSequence"
|
|
: "_ods_ir.Region");
|
|
}
|
|
}
|
|
|
|
/// Emits builder that extracts results from op
|
|
static void emitValueBuilder(const Operator &op,
|
|
SmallVector<std::string> functionArgs,
|
|
raw_ostream &os) {
|
|
// Parse a formatted function arg "name[: type][ = default]" into
|
|
// (name, type, defaultVal) with whitespace trimmed.
|
|
auto parseFunctionArg =
|
|
[](StringRef arg) -> std::tuple<StringRef, StringRef, StringRef> {
|
|
auto [nameAndType, defaultVal] = arg.split('=');
|
|
auto [name, type] = nameAndType.split(':');
|
|
return {name.trim(), type.trim(), defaultVal.trim()};
|
|
};
|
|
|
|
// Params with (possibly) default args.
|
|
auto valueBuilderParams =
|
|
llvm::map_range(functionArgs, [&](const std::string &arg) {
|
|
auto [name, type, defaultVal] = parseFunctionArg(arg);
|
|
std::string result = llvm::convertToSnakeFromCamelCase(name);
|
|
if (!type.empty())
|
|
result += ": " + type.str();
|
|
if (!defaultVal.empty())
|
|
result += " = " + defaultVal.str();
|
|
return result;
|
|
});
|
|
// Actual args passed to op builder (e.g., opParam=op_param).
|
|
auto opBuilderArgs = llvm::map_range(
|
|
llvm::make_filter_range(functionArgs,
|
|
[](const std::string &s) { return s != "*"; }),
|
|
[&](const std::string &arg) {
|
|
auto [name, type, defaultVal] = parseFunctionArg(arg);
|
|
return (name + "=" + llvm::convertToSnakeFromCamelCase(name)).str();
|
|
});
|
|
std::string nameWithoutDialect = sanitizeName(
|
|
op.getOperationName().substr(op.getOperationName().find('.') + 1));
|
|
if (nameWithoutDialect == op.getCppClassName())
|
|
nameWithoutDialect += "_";
|
|
std::string params = llvm::join(valueBuilderParams, ", ");
|
|
std::string args = llvm::join(opBuilderArgs, ", ");
|
|
if (op.getNumVariableLengthResults()) {
|
|
os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
|
|
op.getCppClassName(), params, args);
|
|
} else {
|
|
std::string type = op.getCppClassName().str();
|
|
const char *results = "";
|
|
if (op.getNumResults() > 1) {
|
|
type = "_ods_ir.OpResultList";
|
|
results = ".results";
|
|
} else if (op.getNumResults() == 1) {
|
|
type = "_ods_ir.OpResult";
|
|
if (StringRef pythonType =
|
|
getPythonType(op.getResult(0).constraint.getCppType());
|
|
!pythonType.empty())
|
|
type = llvm::formatv("{0}[{1}]", type, pythonType);
|
|
results = ".result";
|
|
}
|
|
os << formatv(valueBuilderTemplate, nameWithoutDialect,
|
|
op.getCppClassName(), params, args, type, results);
|
|
}
|
|
}
|
|
|
|
/// Retrieve the description of the given op and generate a docstring for it.
|
|
static std::string makeDocStringForOp(const Operator &op) {
|
|
if (!op.hasDescription())
|
|
return "";
|
|
|
|
auto desc = op.getDescription().rtrim(" \t").str();
|
|
// Replace all """ with \"\"\" to avoid early termination of the literal.
|
|
desc = std::regex_replace(desc, std::regex(R"(""")"), R"(\"\"\")");
|
|
|
|
std::string docString = "\n";
|
|
llvm::raw_string_ostream os(docString);
|
|
raw_indented_ostream identedOs(os);
|
|
os << R"( r""")" << "\n";
|
|
identedOs.printReindented(desc, " ");
|
|
if (!StringRef(desc).ends_with("\n"))
|
|
os << "\n";
|
|
os << R"( """)" << "\n";
|
|
|
|
return docString;
|
|
}
|
|
|
|
/// Emits the operand accessors of the generated adaptor class, reusing the
/// generic element-accessor emitter in adaptor mode.
static void emitAdaptorOperandAccessors(const Operator &op, raw_ostream &os) {
  const auto numVariadicGroups = op.getNumVariableLengthOperands();
  emitElementAccessors(op, os, "operand", numVariadicGroups,
                       getNumOperands(op), getOperand, /*isAdaptor=*/true);
}
|
|
|
|
/// Emits bindings for a specific Op to the given output stream.
|
|
static void emitOpBindings(const Operator &op, raw_ostream &os) {
|
|
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
|
|
makeDocStringForOp(op));
|
|
|
|
// Sized segments.
|
|
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
|
|
emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
|
|
}
|
|
if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
|
|
emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
|
|
}
|
|
|
|
emitRegionAttributes(op, os);
|
|
SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
|
|
emitOperandAccessors(op, os);
|
|
emitAttributeAccessors(op, os);
|
|
emitResultAccessors(op, os);
|
|
emitRegionAccessors(op, os);
|
|
|
|
os << formatv(opAdaptorClassTemplate, op.getCppClassName(),
|
|
op.getOperationName());
|
|
emitAdaptorOperandAccessors(op, os);
|
|
emitAdaptorAttributeAccessors(op, os);
|
|
|
|
emitValueBuilder(op, functionArgs, os);
|
|
}
|
|
|
|
/// Emits bindings for the dialect specified in the command line, including file
|
|
/// headers and utilities. Returns `false` on success to comply with Tablegen
|
|
/// registration requirements.
|
|
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
|
|
if (dialectNameStorage.empty())
|
|
llvm::PrintFatalError("dialect name not provided");
|
|
|
|
os << fileHeader;
|
|
if (!clDialectExtensionName.empty())
|
|
os << formatv(dialectExtensionTemplate, dialectNameStorage);
|
|
else
|
|
os << formatv(dialectClassTemplate, dialectNameStorage);
|
|
|
|
for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
|
|
Operator op(rec);
|
|
if (op.getDialectName() == dialectNameStorage)
|
|
emitOpBindings(op, os);
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Registers this backend with mlir-tblgen under `-gen-python-op-bindings`.
static GenRegistration genPythonBindings{
    "gen-python-op-bindings", "Generate Python bindings for MLIR Ops",
    &emitAllOps};
|