[mlir][ods] Support string literals in custom directives

This patch adds support for string literals as `custom` directive
arguments. This can be useful for re-using custom parsers and printers
when arguments have a known value. For example:

```
ParseResult parseTypedAttr(AsmParser &parser, Attribute &attr, Type type) {
  return parser.parseAttribute(attr, type);
}

void printTypedAttr(AsmPrinter &printer, Attribute attr, Type type) {
  return parser.printAttributeWithoutType(attr);
}
```

And in TableGen:

```
def FooOp : ... {
  let arguments = (ins AnyAttr:$a);
  let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getI1Type()")
                          attr-dict }];
}

def BarOp : ... {
  let arguments = (ins AnyAttr:$a);
  let assemblyFormat = [{ custom<TypedAttr>($a, "$_builder.getIndexType()")
                          attr-dict }];
}
```

Instead of writing two separate sets of custom parsers and printers.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D131603
This commit is contained in:
Jeff Niu 2022-08-10 13:37:11 -04:00
parent 2ca27206f9
commit a2ad3ec7ac
16 changed files with 274 additions and 154 deletions

View File

@ -895,6 +895,19 @@ void printStringParam(AsmPrinter &printer, StringRef value);
The custom parser is considered to have failed if it returns failure or if any
bound parameters have failure values afterwards.
A string of C++ code can be used as a `custom` directive argument. When
generating the custom parser and printer call, the string is pasted as a
function argument. For example, `parseBar` and `printBar` can be re-used with
a constant integer:
```tablegen
let parameters = (ins "int":$bar);
let assemblyFormat = [{ custom<Bar>($foo, "1") }];
```
The string is pasted verbatim but with substitutions for `$_builder` and
`$_ctxt`. String literals can be used to parameterize custom directives.
### Verification
If the `genVerifyDecl` field is set, additional verification methods are

View File

@ -768,9 +768,9 @@ when generating the C++ code for the format. The `UserDirective` is an
identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)`
would result in calls to `parseMyDirective` and `printMyDirective` within the
parser and printer respectively. `Params` may be any combination of variables
(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`.
The type directives must refer to a variable, but that variable need not also be
a parameter to the custom directive.
(i.e. Attribute, Operand, Successor, etc.), type directives, `attr-dict`, and
strings of C++ code. The type directives must refer to a variable, but that
variable need not also be a parameter to the custom directive.
The arguments to the `parse<UserDirective>` method are firstly a reference to
the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters
@ -837,7 +837,16 @@ declarative parameter to `print` method argument is detailed below:
- VariadicOfVariadic: `TypeRangeRange`
* `attr-dict` Directive: `DictionaryAttr`
When a variable is optional, the provided value may be null.
When a variable is optional, the provided value may be null. When a variable is
referenced in a custom directive parameter using `ref`, it is passed in by
value. Referenced variables to `print<UserDirective>` are passed as the same as
bound variables, but referenced variables to `parse<UserDirective>` are passed
like to the printer.
A custom directive can take a string of C++ code as a parameter. The code is
pasted verbatim in the calls to the custom parser and printers, with the
substitutions `$_builder` and `$_ctxt`. String literals can be used to
parameterize custom directives.
#### Optional Groups
@ -1462,7 +1471,7 @@ std::string stringifyMyBitEnum(MyBitEnum symbol) {
if (2u == (2u & val)) { strs.push_back("Bit1"); }
if (4u == (4u & val)) { strs.push_back("Bit2"); }
if (8u == (8u & val)) { strs.push_back("Bit3"); }
return llvm::join(strs, "|");
}

View File

@ -45,8 +45,9 @@ def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
let results = (outs AnyTensor:$result);
let assemblyFormat = [{
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
`:` type($result)
custom<DynamicIndexList>($sizes, $static_sizes,
"ShapedType::kDynamicSize")
attr-dict `:` type($result)
}];
let extraClassDeclaration = [{

View File

@ -1023,11 +1023,14 @@ def MemRef_ReinterpretCastOp
let assemblyFormat = [{
$source `to` `offset` `` `:`
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<DynamicIndexList>($offsets, $static_offsets,
"ShapedType::kDynamicStrideOrOffset")
`` `,` `sizes` `` `:`
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) `` `,` `strides`
`` `:`
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
custom<DynamicIndexList>($sizes, $static_sizes,
"ShapedType::kDynamicSize")
`` `,` `strides` `` `:`
custom<DynamicIndexList>($strides, $static_strides,
"ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `to` type($result)
}];
@ -1586,9 +1589,12 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
let assemblyFormat = [{
$source ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
custom<DynamicIndexList>($offsets, $static_offsets,
"ShapedType::kDynamicStrideOrOffset")
custom<DynamicIndexList>($sizes, $static_sizes,
"ShapedType::kDynamicSize")
custom<DynamicIndexList>($strides, $static_strides,
"ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `to` type($result)
}];

View File

@ -219,11 +219,11 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
To disambiguate, the inference helpers `inferCanonicalRankReducedResultType`
only drop the first unit dimensions, in order:
e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6.
Verification however has access to result type and does not need to infer.
The verifier calls `isRankReducedType(getSource(), getResult())` to
The verifier calls `isRankReducedType(getSource(), getResult())` to
determine whether the result type is rank-reduced from the source type.
This computes a so-called rank-reduction mask, consisting of dropped unit
This computes a so-called rank-reduction mask, consisting of dropped unit
dims, to map the rank-reduced type to the source type by dropping ones:
e.g. 1x6 is a rank-reduced version of 1x6x1 by mask {2}
6x1 is a rank-reduced version of 1x6x1 by mask {0}
@ -254,9 +254,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
let assemblyFormat = [{
$source ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
custom<DynamicIndexList>($offsets, $static_offsets,
"ShapedType::kDynamicStrideOrOffset")
custom<DynamicIndexList>($sizes, $static_sizes,
"ShapedType::kDynamicSize")
custom<DynamicIndexList>($strides, $static_strides,
"ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `to` type($result)
}];
@ -298,12 +301,12 @@ def Tensor_ExtractSliceOp : Tensor_OpWithOffsetSizesAndStrides<"extract_slice",
/// tensor type to the result tensor type by dropping unit dims.
llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeRankReductionMask() {
return ::mlir::computeRankReductionMask(getSourceType().getShape(),
return ::mlir::computeRankReductionMask(getSourceType().getShape(),
getType().getShape());
};
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
static RankedTensorType inferResultType(
ShapedType sourceShapedTensorType,
@ -580,9 +583,12 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
let assemblyFormat = [{
$source `into` $dest ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
custom<DynamicIndexList>($offsets, $static_offsets,
"ShapedType::kDynamicStrideOrOffset")
custom<DynamicIndexList>($sizes, $static_sizes,
"ShapedType::kDynamicSize")
custom<DynamicIndexList>($strides, $static_strides,
"ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `into` type($dest)
}];
@ -608,7 +614,7 @@ def Tensor_InsertSliceOp : Tensor_OpWithOffsetSizesAndStrides<"insert_slice", [
RankedTensorType getType() {
return getResult().getType().cast<RankedTensorType>();
}
/// The `dest` type is the same as the result type.
RankedTensorType getDestType() {
return getType();
@ -962,8 +968,10 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
let assemblyFormat = [{
$source
(`nofold` $nofold^)?
`low` `` custom<OperandsOrIntegersSizesList>($low, $static_low)
`high` `` custom<OperandsOrIntegersSizesList>($high, $static_high)
`low` `` custom<DynamicIndexList>($low, $static_low,
"ShapedType::kDynamicSize")
`high` `` custom<DynamicIndexList>($high, $static_high,
"ShapedType::kDynamicSize")
$region attr-dict `:` type($source) `to` type($result)
}];
@ -1069,15 +1077,15 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
// HasParent<"ParallelCombiningOpInterface">
]> {
let summary = [{
Specify the tensor slice update of a single thread of a parent
Specify the tensor slice update of a single thread of a parent
ParallelCombiningOpInterface op.
}];
let description = [{
The `parallel_insert_slice` yields a subset tensor value to its parent
The `parallel_insert_slice` yields a subset tensor value to its parent
ParallelCombiningOpInterface. These subset tensor values are aggregated to
in some unspecified order into a full tensor value returned by the parent
parallel iterating op.
The `parallel_insert_slice` is one such op allowed in the
in some unspecified order into a full tensor value returned by the parent
parallel iterating op.
The `parallel_insert_slice` is one such op allowed in the
ParallelCombiningOpInterface op.
Conflicting writes result in undefined semantics, in that the indices written
@ -1118,12 +1126,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
into a memref.subview op.
A parallel_insert_slice operation may additionally specify insertion into a
tensor of higher rank than the source tensor, along dimensions that are
tensor of higher rank than the source tensor, along dimensions that are
statically known to be of size 1.
This rank-altering behavior is not required by the op semantics: this
flexibility allows to progressively drop unit dimensions while lowering
between different flavors of ops on that operate on tensors.
The rank-altering behavior of tensor.parallel_insert_slice matches the
The rank-altering behavior of tensor.parallel_insert_slice matches the
rank-reducing behavior of tensor.insert_slice and tensor.extract_slice.
Verification in the rank-reduced case:
@ -1144,9 +1152,12 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
);
let assemblyFormat = [{
$source `into` $dest ``
custom<OperandsOrIntegersOffsetsOrStridesList>($offsets, $static_offsets)
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes)
custom<OperandsOrIntegersOffsetsOrStridesList>($strides, $static_strides)
custom<DynamicIndexList>($offsets, $static_offsets,
"ShapedType::kDynamicStrideOrOffset")
custom<DynamicIndexList>($sizes, $static_sizes,
"ShapedType::kDynamicSize")
custom<DynamicIndexList>($strides, $static_strides,
"ShapedType::kDynamicStrideOrOffset")
attr-dict `:` type($source) `into` type($dest)
}];
@ -1194,7 +1205,7 @@ def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
"ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
];
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;

View File

@ -84,73 +84,40 @@ namespace mlir {
/// Printer hook for custom directive in assemblyFormat.
///
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
/// custom<DynamicIndexList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
/// either (1) the static integer value in `integers` if the value is
/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This
/// allows idiomatic printing of mixed value and integer attributes in a
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer,
Operation *op,
OperandRange values,
ArrayAttr integers);
/// Printer hook for custom directive in assemblyFormat.
///
/// custom<OperandsOrIntegersSizesList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with
/// either (1) the static integer value in `integers` if the value is
/// ShapedType::kDynamicSize or (2) the next value otherwise. This
/// allows idiomatic printing of mixed value and integer attributes in a
/// list. E.g. `[%arg0, 7, 42, %arg42]`.
void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op,
OperandRange values, ArrayAttr integers);
/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value
/// in `integers` is `dynVal` or (2) the next value otherwise. This allows
/// idiomatic printing of mixed value and integer attributes in a list. E.g.
/// `[%arg0, 7, 42, %arg42]`.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values, ArrayAttr integers,
int64_t dynVal);
/// Pasrer hook for custom directive in assemblyFormat.
///
/// custom<OperandsOrIntegersOffsetsOrStridesList>($values, $integers)
/// custom<DynamicIndexList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
/// either (1) static integer values or (2) SSA values. Fill `integers` with
/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the
/// position of SSA values. Add the parsed SSA values to `values` in-order.
/// type `I64ArrayAttr`. Parse a mixed list with either (1) static integer
/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where
/// `dynVal` encodes the position of SSA values. Add the parsed SSA values
/// to `values` in-order.
//
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
ParseResult parseOperandsOrIntegersOffsetsOrStridesList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers);
/// Pasrer hook for custom directive in assemblyFormat.
///
/// custom<OperandsOrIntegersSizesList>($values, $integers)
///
/// where `values` is of ODS type `Variadic<Index>` and `integers` is of ODS
/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with
/// either (1) static integer values or (2) SSA values. Fill `integers` with
/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the
/// position of SSA values. Add the parsed SSA values to `values` in-order.
//
/// E.g. after parsing "[%arg0, 7, 42, %arg42]":
/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]"
/// 2. `ssa` is filled with "[%arg0, %arg1]".
ParseResult parseOperandsOrIntegersSizesList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers);
ParseResult
parseDynamicIndexList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers, int64_t dynVal);
/// Verify that a the `values` has as many elements as the number of entries in
/// `attr` for which `isDynamic` evaluates to true.
LogicalResult verifyListOfOperandsOrIntegers(
Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr,
ValueRange values, llvm::function_ref<bool(int64_t)> isDynamic);
ValueRange values, function_ref<bool(int64_t)> isDynamic);
} // namespace mlir

View File

@ -987,7 +987,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
auto pdlOperationType = pdl::OperationType::get(parser.getContext());
if (parser.parseOperand(target) ||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes,
ShapedType::kDynamicSize) ||
parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
parser.parseOptionalAttrDict(result.attributes))
return ParseResult::failure();
@ -1001,8 +1002,8 @@ ParseResult transform::TileOp::parse(OpAsmParser &parser,
void TileOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(),
getStaticSizes());
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
ShapedType::kDynamicSize);
p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
}

View File

@ -70,45 +70,29 @@ mlir::detail::verifyOffsetSizeAndStrideOp(OffsetSizeAndStrideOpInterface op) {
return success();
}
template <int64_t dynVal>
static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values,
ArrayAttr arrayAttr) {
p << '[';
if (arrayAttr.empty()) {
p << "]";
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values, ArrayAttr integers,
int64_t dynVal) {
printer << '[';
if (integers.empty()) {
printer << "]";
return;
}
unsigned idx = 0;
llvm::interleaveComma(arrayAttr, p, [&](Attribute a) {
llvm::interleaveComma(integers, printer, [&](Attribute a) {
int64_t val = a.cast<IntegerAttr>().getInt();
if (val == dynVal)
p << values[idx++];
printer << values[idx++];
else
p << val;
printer << val;
});
p << ']';
printer << ']';
}
void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p,
Operation *op,
OperandRange values,
ArrayAttr integers) {
return printOperandsOrIntegersListImpl<ShapedType::kDynamicStrideOrOffset>(
p, values, integers);
}
void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op,
OperandRange values,
ArrayAttr integers) {
return printOperandsOrIntegersListImpl<ShapedType::kDynamicSize>(p, values,
integers);
}
template <int64_t dynVal>
static ParseResult parseOperandsOrIntegersImpl(
ParseResult mlir::parseDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers) {
ArrayAttr &integers, int64_t dynVal) {
if (failed(parser.parseLSquare()))
return failure();
// 0-D.
@ -142,22 +126,6 @@ static ParseResult parseOperandsOrIntegersImpl(
return success();
}
ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers) {
return parseOperandsOrIntegersImpl<ShapedType::kDynamicStrideOrOffset>(
parser, values, integers);
}
ParseResult mlir::parseOperandsOrIntegersSizesList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
ArrayAttr &integers) {
return parseOperandsOrIntegersImpl<ShapedType::kDynamicSize>(parser, values,
integers);
}
bool mlir::detail::sameOffsetsSizesAndStrides(
OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b,
llvm::function_ref<bool(OpFoldResult, OpFoldResult)> cmp) {

View File

@ -555,3 +555,19 @@ def TypeK : TestType<"TestM"> {
let mnemonic = "type_k";
let assemblyFormat = "$a";
}
// TYPE-LABEL: ::mlir::Type TestNType::parse
// TYPE: parseFoo(
// TYPE-NEXT: _result_a,
// TYPE-NEXT: 1);
// TYPE-LABEL: void TestNType::print
// TYPE: printFoo(
// TYPE-NEXT: getA(),
// TYPE-NEXT: 1);
def TypeL : TestType<"TestN"> {
let parameters = (ins "int":$a);
let mnemonic = "type_l";
let assemblyFormat = [{ custom<Foo>($a, "1") }];
}

View File

@ -403,6 +403,13 @@ def OptionalInvalidP : TestFormat_Op<[{
($arg^):(`test`)
}]>, Arguments<(ins Variadic<I64>:$arg)>;
//===----------------------------------------------------------------------===//
// Strings
//===----------------------------------------------------------------------===//
// CHECK: error: strings may only be used as 'custom' directive arguments
def StringInvalidA : TestFormat_Op<[{ "foo" }]>;
//===----------------------------------------------------------------------===//
// Variables
//===----------------------------------------------------------------------===//

View File

@ -135,6 +135,13 @@ def OptionalValidA : TestFormat_Op<[{
(` ` `` $arg^)? attr-dict
}]>, Arguments<(ins Optional<I32>:$arg)>;
//===----------------------------------------------------------------------===//
// Strings
//===----------------------------------------------------------------------===//
// CHECK-NOT: error
def StringInvalidA : TestFormat_Op<[{ custom<Foo>("foo") attr-dict }]>;
//===----------------------------------------------------------------------===//
// Variables
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,42 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
def TestDialect : Dialect {
let name = "test";
}
class TestFormat_Op<string fmt, list<Trait> traits = []>
: Op<TestDialect, "format_op", traits> {
let assemblyFormat = fmt;
}
//===----------------------------------------------------------------------===//
// Directives
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// custom
// CHECK-LABEL: CustomStringLiteralA::parse
// CHECK: parseFoo({{.*}}, parser.getBuilder().getI1Type())
// CHECK-LABEL: CustomStringLiteralA::print
// CHECK: printFoo({{.*}}, parser.getBuilder().getI1Type())
def CustomStringLiteralA : TestFormat_Op<[{
custom<Foo>("$_builder.getI1Type()") attr-dict
}]>;
// CHECK-LABEL: CustomStringLiteralB::parse
// CHECK: parseFoo({{.*}}, IndexType::get(parser.getContext()))
// CHECK-LABEL: CustomStringLiteralB::print
// CHECK: printFoo({{.*}}, IndexType::get(parser.getContext()))
def CustomStringLiteralB : TestFormat_Op<[{
custom<Foo>("IndexType::get($_ctxt)") attr-dict
}]>;
// CHECK-LABEL: CustomStringLiteralC::parse
// CHECK: parseFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
// CHECK-LABEL: CustomStringLiteralC::print
// CHECK: printFoo({{.*}}, parser.getBuilder().getStringAttr("foo"))
def CustomStringLiteralC : TestFormat_Op<[{
custom<Foo>("$_builder.getStringAttr(\"foo\")") attr-dict
}]>;

View File

@ -629,14 +629,12 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
os.indent();
for (FormatElement *arg : el->getArguments()) {
os << ",\n";
FormatElement *param;
if (auto *ref = dyn_cast<RefDirective>(arg)) {
os << "*";
param = ref->getArg();
} else {
param = arg;
}
os << "_result_" << cast<ParameterElement>(param)->getName();
if (auto *param = dyn_cast<ParameterElement>(arg))
os << "_result_" << param->getName();
else if (auto *ref = dyn_cast<RefDirective>(arg))
os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
else
os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
}
os.unindent() << ");\n";
os << "if (::mlir::failed(odsCustomResult)) return {};\n";
@ -845,11 +843,15 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
os << tgfmt("print$0($_printer", &ctx, el->getName());
os.indent();
for (FormatElement *arg : el->getArguments()) {
FormatElement *param = arg;
if (auto *ref = dyn_cast<RefDirective>(arg))
param = ref->getArg();
os << ",\n"
<< cast<ParameterElement>(param)->getParam().getAccessorName() << "()";
os << ",\n";
if (auto *param = dyn_cast<ParameterElement>(arg)) {
os << param->getParam().getAccessorName() << "()";
} else if (auto *ref = dyn_cast<RefDirective>(arg)) {
os << cast<ParameterElement>(ref->getArg())->getParam().getAccessorName()
<< "()";
} else {
os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
}
}
os.unindent() << ");\n";
}

View File

@ -129,6 +129,8 @@ FormatToken FormatLexer::lexToken() {
return lexLiteral(tokStart);
case '$':
return lexVariable(tokStart);
case '"':
return lexString(tokStart);
}
}
@ -153,6 +155,17 @@ FormatToken FormatLexer::lexVariable(const char *tokStart) {
return formToken(FormatToken::variable, tokStart);
}
FormatToken FormatLexer::lexString(const char *tokStart) {
// Lex until another quote, respecting escapes.
bool escape = false;
while (const char curChar = *curPtr++) {
if (!escape && curChar == '"')
return formToken(FormatToken::string, tokStart);
escape = curChar == '\\';
}
return emitError(curPtr - 1, "unexpected end of file in string");
}
FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
// Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
@ -212,6 +225,8 @@ FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
if (curToken.is(FormatToken::literal))
return parseLiteral(ctx);
if (curToken.is(FormatToken::string))
return parseString(ctx);
if (curToken.is(FormatToken::variable))
return parseVariable(ctx);
if (curToken.isKeyword())
@ -253,6 +268,28 @@ FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
return create<LiteralElement>(value);
}
FailureOr<FormatElement *> FormatParser::parseString(Context ctx) {
FormatToken tok = curToken;
SMLoc loc = tok.getLoc();
consumeToken();
if (ctx != CustomDirectiveContext) {
return emitError(
loc, "strings may only be used as 'custom' directive arguments");
}
// Escape the string.
std::string value;
StringRef contents = tok.getSpelling().drop_front().drop_back();
value.reserve(contents.size());
bool escape = false;
for (char c : contents) {
escape = c == '\\';
if (!escape)
value.push_back(c);
}
return create<StringElement>(std::move(value));
}
FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
FormatToken tok = curToken;
SMLoc loc = tok.getLoc();

View File

@ -78,6 +78,7 @@ public:
identifier,
literal,
variable,
string,
};
FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
@ -130,10 +131,11 @@ private:
/// Return the next character in the stream.
int getNextChar();
/// Lex an identifier, literal, or variable.
/// Lex an identifier, literal, variable, or string.
FormatToken lexIdentifier(const char *tokStart);
FormatToken lexLiteral(const char *tokStart);
FormatToken lexVariable(const char *tokStart);
FormatToken lexString(const char *tokStart);
/// Create a token with the current pointer and a start pointer.
FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
@ -163,7 +165,7 @@ public:
virtual ~FormatElement();
// The top-level kinds of format elements.
enum Kind { Literal, Variable, Whitespace, Directive, Optional };
enum Kind { Literal, String, Variable, Whitespace, Directive, Optional };
/// Support LLVM-style RTTI.
static bool classof(const FormatElement *el) { return true; }
@ -212,6 +214,20 @@ private:
StringRef spelling;
};
/// This class represents a raw string that can contain arbitrary C++ code.
class StringElement : public FormatElementBase<FormatElement::String> {
public:
/// Create a string element with the given contents.
explicit StringElement(std::string value) : value(std::move(value)) {}
/// Get the value of the string element.
StringRef getValue() const { return value; }
private:
/// The contents of the string.
std::string value;
};
/// This class represents a variable element. A variable refers to some part of
/// the object being parsed, e.g. an attribute or operand on an operation or a
/// parameter on an attribute.
@ -447,6 +463,8 @@ protected:
FailureOr<FormatElement *> parseElement(Context ctx);
/// Parse a literal.
FailureOr<FormatElement *> parseLiteral(Context ctx);
/// Parse a string.
FailureOr<FormatElement *> parseString(Context ctx);
/// Parse a variable.
FailureOr<FormatElement *> parseVariable(Context ctx);
/// Parse a directive.

View File

@ -916,6 +916,13 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
body << llvm::formatv("{0}Type", listName);
else
body << formatv("{0}RawTypes[0]", listName);
} else if (auto *string = dyn_cast<StringElement>(param)) {
FmtContext ctx;
ctx.withBuilder("parser.getBuilder()");
ctx.addSubst("_ctxt", "parser.getContext()");
body << tgfmt(string->getValue(), &ctx);
} else {
llvm_unreachable("unknown custom directive parameter");
}
@ -1715,6 +1722,13 @@ static void genCustomDirectiveParameterPrinter(FormatElement *element,
body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
else
body << name << "().getType()";
} else if (auto *string = dyn_cast<StringElement>(element)) {
FmtContext ctx;
ctx.withBuilder("parser.getBuilder()");
ctx.addSubst("_ctxt", "parser.getContext()");
body << tgfmt(string->getValue(), &ctx);
} else {
llvm_unreachable("unknown custom directive parameter");
}
@ -2826,8 +2840,9 @@ OpFormatParser::parseAttrDictDirective(SMLoc loc, Context context,
LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
SMLoc loc, ArrayRef<FormatElement *> arguments) {
for (FormatElement *argument : arguments) {
if (!isa<RefDirective, TypeDirective, AttrDictDirective, AttributeVariable,
OperandVariable, RegionVariable, SuccessorVariable>(argument)) {
if (!isa<StringElement, RefDirective, TypeDirective, AttrDictDirective,
AttributeVariable, OperandVariable, RegionVariable,
SuccessorVariable>(argument)) {
// TODO: FormatElement should have location info attached.
return emitError(loc, "only variables and types may be used as "
"parameters to a custom directive");