[mlir][ods] Support string literals in custom
directives
This patch adds support for string literals as `custom` directive arguments. This can be useful for re-using custom parsers and printers when arguments have a known value. For example: ``` ParseResult parseTypedAttr(AsmParser &parser, Attribute &attr, Type type) { return parser.parseAttribute(attr, type); } void printTypedAttr(AsmPrinter &printer, Attribute attr, Type type) { return parser.printAttributeWithoutType(attr); } ``` And in TableGen: ``` def FooOp : ... { let arguments = (ins AnyAttr:$a); let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getI1Type()") attr-dict }]; } def BarOp : ... { let arguments = (ins AnyAttr:$a); let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getIndexType()") attr-dict }]; } ``` Instead of writing two separate sets of custom parsers and printers. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D131603
This commit is contained in:
parent
2ca27206f9
commit
a2ad3ec7ac
@ -895,6 +895,19 @@ void printStringParam(AsmPrinter &printer, StringRef value);
|
||||
The custom parser is considered to have failed if it returns failure or if any
|
||||
bound parameters have failure values afterwards.
|
||||
|
||||
A string of C++ code can be used as a `custom` directive argument. When
|
||||
generating the custom parser and printer call, the string is pasted as a
|
||||
function argument. For example, `parseBar` and `printBar` can be re-used with
|
||||
a constant integer:
|
||||
|
||||
```tablegen
|
||||
let parameters = (ins "int":$bar);
|
||||
let assemblyFormat = [{ custom<Bar>($foo, "1") }];
|
||||
```
|
||||
|
||||
The string is pasted verbatim but with substitutions for `$_builder` and
|
||||
`$_ctxt`. String literals can be used to parameterize custom directives.
|
||||
|
||||
### Verification
|
||||
|
||||
If the `genVerifyDecl` field is set, additional verification methods are
|
||||
|
@ -768,9 +768,9 @@ when generating the C++ code for the format. The `UserDirective` is an
|
||||
identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
|
||||
would result in calls to `parseMyDirective` and `printMyDirective` within the
|
||||
parser and printer respectively. `Params` may be any combination of variables
|
||||
(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`.
|
||||
The type directives must refer to a variable, but that variable need not also be
|
||||
a parameter to the custom directive.
|
||||
(i.e. Attribute, Operand, Successor, etc.), type directives, `attr-dict`, and
|
||||
strings of C++ code. The type directives must refer to a variable, but that
|
||||
variable need not also be a parameter to the custom directive.
|
||||
|
||||
The arguments to the `parse<UserDirective>` method are firstly a reference to
|
||||
the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
|
||||
@ -837,7 +837,16 @@ declarative parameter to `print` method argument is detailed below:
|
||||
- VariadicOfVariadic: `TypeRangeRange`
|
||||
* `attr-dict` Directive: `DictionaryAttr`
|
||||
|
||||
When a variable is optional, the provided value may be null.
|
||||
When a variable is optional, the provided value may be null. When a variable is
|
||||
referenced in a custom directive parameter using `ref`, it is passed in by
|
||||
value. Referenced variables to `print<UserDirective>` are passed as the same as
|
||||
bound variables, but referenced variables to `parse<UserDirective>` are passed
|
||||
like to the printer.
|
||||
|
||||
A custom directive can take a string of C++ code as a parameter. The code is
|
||||
pasted verbatim in the calls to the custom parser and printers, with the
|
||||
substitutions `$_builder` and `$_ctxt`. String literals can be used to
|
||||
parameterize custom directives.
|
||||
|
||||
#### Optional Groups
|
||||
|
||||
@ -1462,7 +1471,7 @@ std::string stringifyMyBitEnum(MyBitEnum symbol) {
|
||||
if (2u == (2u & val)) { strs.push_back("Bit1"); }
|
||||
if (4u == (4u & val)) { strs.push_back("Bit2"); }
|
||||
if (8u == (8u & val)) { strs.push_back("Bit3"); }
|
||||
|
||||
|
||||
return llvm::join(strs, "|");
|
||||
}
|
||||
|
||||
|
@ -45,8 +45,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
|
||||
let results = (outs AnyTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
|
||||
`:` type($result)
|
||||
custom<DynamicIndexList>($sizes, $static_sizes,
|
||||
"ShapedType::kDynamicSize")
|
||||
attr-dict `:` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -1023,11 +1023,14 @@ def MemRef_ReinterpretCastOp
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `to` `offset` `` `:`
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<DynamicIndexList>($offsets, $static_offsets,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
`` `,` `sizes` `` `:`
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) `` `,` `strides`
|
||||
`` `:`
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
custom<DynamicIndexList>($sizes, $static_sizes,
|
||||
"ShapedType::kDynamicSize")
|
||||
`` `,` `strides` `` `:`
|
||||
custom<DynamicIndexList>($strides, $static_strides,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
@ -1586,9 +1589,12 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source ``
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
custom<DynamicIndexList>($offsets, $static_offsets,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
custom<DynamicIndexList>($sizes, $static_sizes,
|
||||
"ShapedType::kDynamicSize")
|
||||
custom<DynamicIndexList>($strides, $static_strides,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
|
@ -219,11 +219,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
|
||||
To disambiguate, the inference helpers `inferCanonicalRankReducedResultType`
|
||||
only drop the first unit dimensions, in order:
|
||||
e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6.
|
||||
|
||||
|
||||
Verification however has access to result type and does not need to infer.
|
||||
The verifier calls `isRankReducedType(getSource(), getResult())` to
|
||||
The verifier calls `isRankReducedType(getSource(), getResult())` to
|
||||
determine whether the result type is rank-reduced from the source type.
|
||||
This computes a so-called rank-reduction mask, consisting of dropped unit
|
||||
This computes a so-called rank-reduction mask, consisting of dropped unit
|
||||
dims, to map the rank-reduced type to the source type by dropping ones:
|
||||
e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2}
|
||||
6x1 is a rank-reduced version of 1x6x1 by mask {0}
|
||||
@ -254,9 +254,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source ``
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
custom<DynamicIndexList>($offsets, $static_offsets,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
custom<DynamicIndexList>($sizes, $static_sizes,
|
||||
"ShapedType::kDynamicSize")
|
||||
custom<DynamicIndexList>($strides, $static_strides,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
@ -298,12 +301,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
|
||||
/// tensor type to the result tensor type by dropping unit dims.
|
||||
llvm::Optional<llvm::SmallDenseSet<unsigned>>
|
||||
computeRankReductionMask() {
|
||||
return ::mlir::computeRankReductionMask(getSourceType().getShape(),
|
||||
return ::mlir::computeRankReductionMask(getSourceType().getShape(),
|
||||
getType().getShape());
|
||||
};
|
||||
|
||||
/// An extract_slice result type can be inferred, when it is not
|
||||
/// rank-reduced, from the source type and the static representation of
|
||||
/// rank-reduced, from the source type and the static representation of
|
||||
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
|
||||
static RankedTensorType inferResultType(
|
||||
ShapedType sourceShapedTensorType,
|
||||
@ -580,9 +583,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
|
||||
|
||||
let assemblyFormat = [{
|
||||
$source `into` $dest ``
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
custom<DynamicIndexList>($offsets, $static_offsets,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
custom<DynamicIndexList>($sizes, $static_sizes,
|
||||
"ShapedType::kDynamicSize")
|
||||
custom<DynamicIndexList>($strides, $static_strides,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
attr-dict `:` type($source) `into` type($dest)
|
||||
}];
|
||||
|
||||
@ -608,7 +614,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
|
||||
RankedTensorType getType() {
|
||||
return getResult().getType().cast<RankedTensorType>();
|
||||
}
|
||||
|
||||
|
||||
/// The `dest` type is the same as the result type.
|
||||
RankedTensorType getDestType() {
|
||||
return getType();
|
||||
@ -962,8 +968,10 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
|
||||
let assemblyFormat = [{
|
||||
$source
|
||||
(`nofold` $nofold^)?
|
||||
`low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
|
||||
`high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
|
||||
`low` `` custom<DynamicIndexList>($low, $static_low,
|
||||
"ShapedType::kDynamicSize")
|
||||
`high` `` custom<DynamicIndexList>($high, $static_high,
|
||||
"ShapedType::kDynamicSize")
|
||||
$region attr-dict `:` type($source) `to` type($result)
|
||||
}];
|
||||
|
||||
@ -1069,15 +1077,15 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
|
||||
// HasParent<"ParallelCombiningOpInterface">
|
||||
]> {
|
||||
let summary = [{
|
||||
Specify the tensor slice update of a single thread of a parent
|
||||
Specify the tensor slice update of a single thread of a parent
|
||||
ParallelCombiningOpInterface op.
|
||||
}];
|
||||
let description = [{
|
||||
The `parallel_insert_slice` yields a subset tensor value to its parent
|
||||
The `parallel_insert_slice` yields a subset tensor value to its parent
|
||||
ParallelCombiningOpInterface. These subset tensor values are aggregated to
|
||||
in some unspecified order into a full tensor value returned by the parent
|
||||
parallel iterating op.
|
||||
The `parallel_insert_slice` is one such op allowed in the
|
||||
in some unspecified order into a full tensor value returned by the parent
|
||||
parallel iterating op.
|
||||
The `parallel_insert_slice` is one such op allowed in the
|
||||
ParallelCombiningOpInterface op.
|
||||
|
||||
Conflicting writes result in undefined semantics, in that the indices written
|
||||
@ -1118,12 +1126,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
|
||||
into a memref.subview op.
|
||||
|
||||
A parallel_insert_slice operation may additionally specify insertion into a
|
||||
tensor of higher rank than the source tensor, along dimensions that are
|
||||
tensor of higher rank than the source tensor, along dimensions that are
|
||||
statically known to be of size 1.
|
||||
This rank-altering behavior is not required by the op semantics: this
|
||||
flexibility allows to progressively drop unit dimensions while lowering
|
||||
between different flavors of ops on that operate on tensors.
|
||||
The rank-altering behavior of tensor.parallel_insert_slice matches the
|
||||
The rank-altering behavior of tensor.parallel_insert_slice matches the
|
||||
rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
|
||||
|
||||
Verification in the rank-reduced case:
|
||||
@ -1144,9 +1152,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
$source `into` $dest ``
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
|
||||
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
|
||||
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
|
||||
custom<DynamicIndexList>($offsets, $static_offsets,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
custom<DynamicIndexList>($sizes, $static_sizes,
|
||||
"ShapedType::kDynamicSize")
|
||||
custom<DynamicIndexList>($strides, $static_strides,
|
||||
"ShapedType::kDynamicStrideOrOffset")
|
||||
attr-dict `:` type($source) `into` type($dest)
|
||||
}];
|
||||
|
||||
@ -1194,7 +1205,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
|
||||
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
|
||||
];
|
||||
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
|
@ -84,73 +84,40 @@ namespace mlir {
|
||||
|
||||
/// Printer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
|
||||
/// custom<DynamicIndexList>($values, $integers)
|
||||
///
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
|
||||
/// either (1) the static integer value in `integers` if the value is
|
||||
/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This
|
||||
/// allows idiomatic printing of mixed value and integer attributes in a
|
||||
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
|
||||
void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer,
|
||||
Operation *op,
|
||||
OperandRange values,
|
||||
ArrayAttr integers);
|
||||
|
||||
/// Printer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// custom<OperandsOrIntegersSizesList>($values, $integers)
|
||||
///
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
|
||||
/// either (1) the static integer value in `integers` if the value is
|
||||
/// ShapedType::kDynamicSize or (2) the next value otherwise. This
|
||||
/// allows idiomatic printing of mixed value and integer attributes in a
|
||||
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
|
||||
void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op,
|
||||
OperandRange values, ArrayAttr integers);
|
||||
/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value
|
||||
/// in `integers` is `dynVal` or (2) the next value otherwise. This allows
|
||||
/// idiomatic printing of mixed value and integer attributes in a list. E.g.
|
||||
/// `[%arg0, 7, 42, %arg42]`.
|
||||
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
|
||||
OperandRange values, ArrayAttr integers,
|
||||
int64_t dynVal);
|
||||
|
||||
/// Pasrer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
|
||||
/// custom<DynamicIndexList>($values, $integers)
|
||||
///
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
|
||||
/// either (1) static integer values or (2) SSA values. Fill `integers` with
|
||||
/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the
|
||||
/// position of SSA values. Add the parsed SSA values to `values` in-order.
|
||||
/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer
|
||||
/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where
|
||||
/// `dynVal` encodes the position of SSA values. Add the parsed SSA values
|
||||
/// to `values` in-order.
|
||||
//
|
||||
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
|
||||
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
|
||||
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
||||
ParseResult parseOperandsOrIntegersOffsetsOrStridesList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
ArrayAttr &integers);
|
||||
|
||||
/// Pasrer hook for custom directive in assemblyFormat.
|
||||
///
|
||||
/// custom<OperandsOrIntegersSizesList>($values, $integers)
|
||||
///
|
||||
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
|
||||
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
|
||||
/// either (1) static integer values or (2) SSA values. Fill `integers` with
|
||||
/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the
|
||||
/// position of SSA values. Add the parsed SSA values to `values` in-order.
|
||||
//
|
||||
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
|
||||
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
|
||||
/// 2. `ssa` is filled with "[%arg0, %arg1]".
|
||||
ParseResult parseOperandsOrIntegersSizesList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
ArrayAttr &integers);
|
||||
ParseResult
|
||||
parseDynamicIndexList(OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
ArrayAttr &integers, int64_t dynVal);
|
||||
|
||||
/// Verify that a the `values` has as many elements as the number of entries in
|
||||
/// `attr` for which `isDynamic` evaluates to true.
|
||||
LogicalResult verifyListOfOperandsOrIntegers(
|
||||
Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
|
||||
ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic);
|
||||
ValueRange values, function_ref<bool(int64_t)> isDynamic);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -987,7 +987,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
||||
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
|
||||
if (parser.parseOperand(target) ||
|
||||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
|
||||
parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
|
||||
parseDynamicIndexList(parser, dynamicSizes, staticSizes,
|
||||
ShapedType::kDynamicSize) ||
|
||||
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
|
||||
parser.parseOptionalAttrDict(result.attributes))
|
||||
return ParseResult::failure();
|
||||
@ -1001,8 +1002,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
|
||||
|
||||
void TileOp::print(OpAsmPrinter &p) {
|
||||
p << ' ' << getTarget();
|
||||
printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
|
||||
getStaticSizes());
|
||||
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
|
||||
ShapedType::kDynamicSize);
|
||||
p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
|
||||
}
|
||||
|
||||
|
@ -70,45 +70,29 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
template <int64_t dynVal>
|
||||
static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
|
||||
ArrayAttr arrayAttr) {
|
||||
p << '[';
|
||||
if (arrayAttr.empty()) {
|
||||
p << "]";
|
||||
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
|
||||
OperandRange values, ArrayAttr integers,
|
||||
int64_t dynVal) {
|
||||
printer << '[';
|
||||
if (integers.empty()) {
|
||||
printer << "]";
|
||||
return;
|
||||
}
|
||||
unsigned idx = 0;
|
||||
llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
|
||||
llvm::interleaveComma(integers, printer, [&](Attribute a) {
|
||||
int64_t val = a.cast<IntegerAttr>().getInt();
|
||||
if (val == dynVal)
|
||||
p << values[idx++];
|
||||
printer << values[idx++];
|
||||
else
|
||||
p << val;
|
||||
printer << val;
|
||||
});
|
||||
p << ']';
|
||||
printer << ']';
|
||||
}
|
||||
|
||||
void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
|
||||
Operation *op,
|
||||
OperandRange values,
|
||||
ArrayAttr integers) {
|
||||
return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
|
||||
p, values, integers);
|
||||
}
|
||||
|
||||
void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
|
||||
OperandRange values,
|
||||
ArrayAttr integers) {
|
||||
return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
|
||||
integers);
|
||||
}
|
||||
|
||||
template <int64_t dynVal>
|
||||
static ParseResult parseOperandsOrIntegersImpl(
|
||||
ParseResult mlir::parseDynamicIndexList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
ArrayAttr &integers) {
|
||||
ArrayAttr &integers, int64_t dynVal) {
|
||||
if (failed(parser.parseLSquare()))
|
||||
return failure();
|
||||
// 0-D.
|
||||
@ -142,22 +126,6 @@ static ParseResult parseOperandsOrIntegersImpl(
|
||||
return success();
|
||||
}
|
||||
|
||||
ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
ArrayAttr &integers) {
|
||||
return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
|
||||
parser, values, integers);
|
||||
}
|
||||
|
||||
ParseResult mlir::parseOperandsOrIntegersSizesList(
|
||||
OpAsmParser &parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
|
||||
ArrayAttr &integers) {
|
||||
return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
|
||||
integers);
|
||||
}
|
||||
|
||||
bool mlir::detail::sameOffsetsSizesAndStrides(
|
||||
OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
|
||||
llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {
|
||||
|
@ -555,3 +555,19 @@ def TypeK : TestType<"TestM"> {
|
||||
let mnemonic = "type_k";
|
||||
let assemblyFormat = "$a";
|
||||
}
|
||||
|
||||
// TYPE-LABEL: ::mlir::Type TestNType::parse
|
||||
// TYPE: parseFoo(
|
||||
// TYPE-NEXT: _result_a,
|
||||
// TYPE-NEXT: 1);
|
||||
|
||||
// TYPE-LABEL: void TestNType::print
|
||||
// TYPE: printFoo(
|
||||
// TYPE-NEXT: getA(),
|
||||
// TYPE-NEXT: 1);
|
||||
|
||||
def TypeL : TestType<"TestN"> {
|
||||
let parameters = (ins "int":$a);
|
||||
let mnemonic = "type_l";
|
||||
let assemblyFormat = [{ custom<Foo>($a, "1") }];
|
||||
}
|
||||
|
@ -403,6 +403,13 @@ def OptionalInvalidP : TestFormat_Op<[{
|
||||
($arg^):(`test`)
|
||||
}]>, Arguments<(ins Variadic<I64>:$arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Strings
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: error: strings may only be used as 'custom' directive arguments
|
||||
def StringInvalidA : TestFormat_Op<[{ "foo" }]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Variables
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -135,6 +135,13 @@ def OptionalValidA : TestFormat_Op<[{
|
||||
(` ` `` $arg^)? attr-dict
|
||||
}]>, Arguments<(ins Optional<I32>:$arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Strings
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-NOT: error
|
||||
def StringInvalidA : TestFormat_Op<[{ custom<Foo>("foo") attr-dict }]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Variables
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
42
mlir/test/mlir-tblgen/op-format.td
Normal file
42
mlir/test/mlir-tblgen/op-format.td
Normal file
@ -0,0 +1,42 @@
|
||||
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def TestDialect : Dialect {
|
||||
let name = "test";
|
||||
}
|
||||
class TestFormat_Op<string fmt, list<Trait> traits = []>
|
||||
: Op<TestDialect, "format_op", traits> {
|
||||
let assemblyFormat = fmt;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Directives
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// custom
|
||||
|
||||
// CHECK-LABEL: CustomStringLiteralA::parse
|
||||
// CHECK: parseFoo({{.*}}, parser.getBuilder().getI1Type())
|
||||
// CHECK-LABEL: CustomStringLiteralA::print
|
||||
// CHECK: printFoo({{.*}}, parser.getBuilder().getI1Type())
|
||||
def CustomStringLiteralA : TestFormat_Op<[{
|
||||
custom<Foo>("$_builder.getI1Type()") attr-dict
|
||||
}]>;
|
||||
|
||||
// CHECK-LABEL: CustomStringLiteralB::parse
|
||||
// CHECK: parseFoo({{.*}}, IndexType::get(parser.getContext()))
|
||||
// CHECK-LABEL: CustomStringLiteralB::print
|
||||
// CHECK: printFoo({{.*}}, IndexType::get(parser.getContext()))
|
||||
def CustomStringLiteralB : TestFormat_Op<[{
|
||||
custom<Foo>("IndexType::get($_ctxt)") attr-dict
|
||||
}]>;
|
||||
|
||||
// CHECK-LABEL: CustomStringLiteralC::parse
|
||||
// CHECK: parseFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
|
||||
// CHECK-LABEL: CustomStringLiteralC::print
|
||||
// CHECK: printFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
|
||||
def CustomStringLiteralC : TestFormat_Op<[{
|
||||
custom<Foo>("$_builder.getStringAttr(\"foo\")") attr-dict
|
||||
}]>;
|
@ -629,14 +629,12 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
|
||||
os.indent();
|
||||
for (FormatElement *arg : el->getArguments()) {
|
||||
os << ",\n";
|
||||
FormatElement *param;
|
||||
if (auto *ref = dyn_cast<RefDirective>(arg)) {
|
||||
os << "*";
|
||||
param = ref->getArg();
|
||||
} else {
|
||||
param = arg;
|
||||
}
|
||||
os << "_result_" << cast<ParameterElement>(param)->getName();
|
||||
if (auto *param = dyn_cast<ParameterElement>(arg))
|
||||
os << "_result_" << param->getName();
|
||||
else if (auto *ref = dyn_cast<RefDirective>(arg))
|
||||
os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
|
||||
else
|
||||
os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
|
||||
}
|
||||
os.unindent() << ");\n";
|
||||
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
|
||||
@ -845,11 +843,15 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
|
||||
os << tgfmt("print$0($_printer", &ctx, el->getName());
|
||||
os.indent();
|
||||
for (FormatElement *arg : el->getArguments()) {
|
||||
FormatElement *param = arg;
|
||||
if (auto *ref = dyn_cast<RefDirective>(arg))
|
||||
param = ref->getArg();
|
||||
os << ",\n"
|
||||
<< cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
|
||||
os << ",\n";
|
||||
if (auto *param = dyn_cast<ParameterElement>(arg)) {
|
||||
os << param->getParam().getAccessorName() << "()";
|
||||
} else if (auto *ref = dyn_cast<RefDirective>(arg)) {
|
||||
os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName()
|
||||
<< "()";
|
||||
} else {
|
||||
os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
|
||||
}
|
||||
}
|
||||
os.unindent() << ");\n";
|
||||
}
|
||||
|
@ -129,6 +129,8 @@ FormatToken FormatLexer::lexToken() {
|
||||
return lexLiteral(tokStart);
|
||||
case '$':
|
||||
return lexVariable(tokStart);
|
||||
case '"':
|
||||
return lexString(tokStart);
|
||||
}
|
||||
}
|
||||
|
||||
@ -153,6 +155,17 @@ FormatToken FormatLexer::lexVariable(const char *tokStart) {
|
||||
return formToken(FormatToken::variable, tokStart);
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::lexString(const char *tokStart) {
|
||||
// Lex until another quote, respecting escapes.
|
||||
bool escape = false;
|
||||
while (const char curChar = *curPtr++) {
|
||||
if (!escape && curChar == '"')
|
||||
return formToken(FormatToken::string, tokStart);
|
||||
escape = curChar == '\\';
|
||||
}
|
||||
return emitError(curPtr - 1, "unexpected end of file in string");
|
||||
}
|
||||
|
||||
FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
|
||||
// Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
|
||||
while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
|
||||
@ -212,6 +225,8 @@ FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
|
||||
FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
|
||||
if (curToken.is(FormatToken::literal))
|
||||
return parseLiteral(ctx);
|
||||
if (curToken.is(FormatToken::string))
|
||||
return parseString(ctx);
|
||||
if (curToken.is(FormatToken::variable))
|
||||
return parseVariable(ctx);
|
||||
if (curToken.isKeyword())
|
||||
@ -253,6 +268,28 @@ FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
|
||||
return create<LiteralElement>(value);
|
||||
}
|
||||
|
||||
FailureOr<FormatElement *> FormatParser::parseString(Context ctx) {
|
||||
FormatToken tok = curToken;
|
||||
SMLoc loc = tok.getLoc();
|
||||
consumeToken();
|
||||
|
||||
if (ctx != CustomDirectiveContext) {
|
||||
return emitError(
|
||||
loc, "strings may only be used as 'custom' directive arguments");
|
||||
}
|
||||
// Escape the string.
|
||||
std::string value;
|
||||
StringRef contents = tok.getSpelling().drop_front().drop_back();
|
||||
value.reserve(contents.size());
|
||||
bool escape = false;
|
||||
for (char c : contents) {
|
||||
escape = c == '\\';
|
||||
if (!escape)
|
||||
value.push_back(c);
|
||||
}
|
||||
return create<StringElement>(std::move(value));
|
||||
}
|
||||
|
||||
FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
|
||||
FormatToken tok = curToken;
|
||||
SMLoc loc = tok.getLoc();
|
||||
|
@ -78,6 +78,7 @@ public:
|
||||
identifier,
|
||||
literal,
|
||||
variable,
|
||||
string,
|
||||
};
|
||||
|
||||
FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
|
||||
@ -130,10 +131,11 @@ private:
|
||||
/// Return the next character in the stream.
|
||||
int getNextChar();
|
||||
|
||||
/// Lex an identifier, literal, or variable.
|
||||
/// Lex an identifier, literal, variable, or string.
|
||||
FormatToken lexIdentifier(const char *tokStart);
|
||||
FormatToken lexLiteral(const char *tokStart);
|
||||
FormatToken lexVariable(const char *tokStart);
|
||||
FormatToken lexString(const char *tokStart);
|
||||
|
||||
/// Create a token with the current pointer and a start pointer.
|
||||
FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
|
||||
@ -163,7 +165,7 @@ public:
|
||||
virtual ~FormatElement();
|
||||
|
||||
// The top-level kinds of format elements.
|
||||
enum Kind { Literal, Variable, Whitespace, Directive, Optional };
|
||||
enum Kind { Literal, String, Variable, Whitespace, Directive, Optional };
|
||||
|
||||
/// Support LLVM-style RTTI.
|
||||
static bool classof(const FormatElement *el) { return true; }
|
||||
@ -212,6 +214,20 @@ private:
|
||||
StringRef spelling;
|
||||
};
|
||||
|
||||
/// This class represents a raw string that can contain arbitrary C++ code.
|
||||
class StringElement : public FormatElementBase<FormatElement::String> {
|
||||
public:
|
||||
/// Create a string element with the given contents.
|
||||
explicit StringElement(std::string value) : value(std::move(value)) {}
|
||||
|
||||
/// Get the value of the string element.
|
||||
StringRef getValue() const { return value; }
|
||||
|
||||
private:
|
||||
/// The contents of the string.
|
||||
std::string value;
|
||||
};
|
||||
|
||||
/// This class represents a variable element. A variable refers to some part of
|
||||
/// the object being parsed, e.g. an attribute or operand on an operation or a
|
||||
/// parameter on an attribute.
|
||||
@ -447,6 +463,8 @@ protected:
|
||||
FailureOr<FormatElement *> parseElement(Context ctx);
|
||||
/// Parse a literal.
|
||||
FailureOr<FormatElement *> parseLiteral(Context ctx);
|
||||
/// Parse a string.
|
||||
FailureOr<FormatElement *> parseString(Context ctx);
|
||||
/// Parse a variable.
|
||||
FailureOr<FormatElement *> parseVariable(Context ctx);
|
||||
/// Parse a directive.
|
||||
|
@ -916,6 +916,13 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
|
||||
body << llvm::formatv("{0}Type", listName);
|
||||
else
|
||||
body << formatv("{0}RawTypes[0]", listName);
|
||||
|
||||
} else if (auto *string = dyn_cast<StringElement>(param)) {
|
||||
FmtContext ctx;
|
||||
ctx.withBuilder("parser.getBuilder()");
|
||||
ctx.addSubst("_ctxt", "parser.getContext()");
|
||||
body << tgfmt(string->getValue(), &ctx);
|
||||
|
||||
} else {
|
||||
llvm_unreachable("unknown custom directive parameter");
|
||||
}
|
||||
@ -1715,6 +1722,13 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
|
||||
body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
|
||||
else
|
||||
body << name << "().getType()";
|
||||
|
||||
} else if (auto *string = dyn_cast<StringElement>(element)) {
|
||||
FmtContext ctx;
|
||||
ctx.withBuilder("parser.getBuilder()");
|
||||
ctx.addSubst("_ctxt", "parser.getContext()");
|
||||
body << tgfmt(string->getValue(), &ctx);
|
||||
|
||||
} else {
|
||||
llvm_unreachable("unknown custom directive parameter");
|
||||
}
|
||||
@ -2826,8 +2840,9 @@ OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
|
||||
LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
|
||||
SMLoc loc, ArrayRef<FormatElement *> arguments) {
|
||||
for (FormatElement *argument : arguments) {
|
||||
if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
|
||||
OperandVariable, RegionVariable, SuccessorVariable>(argument)) {
|
||||
if (!isa<StringElement, RefDirective, TypeDirective, AttrDictDirective,
|
||||
AttributeVariable, OperandVariable, RegionVariable,
|
||||
SuccessorVariable>(argument)) {
|
||||
// TODO: FormatElement should have location info attached.
|
||||
return emitError(loc, "only variables and types may be used as "
|
||||
"parameters to a custom directive");
|
||||
|
Loading…
x
Reference in New Issue
Block a user