[MLIR][TableGen] Fix ArrayRefParameter in struct format roundtrip (#189065)
When an ArrayRefParameter (or OptionalArrayRefParameter) appears in a non-last position within a struct() assembly format directive, the printed output is ambiguous: the comma-separated array elements are indistinguishable from the struct-level commas separating key-value pairs. Fix this by wrapping such parameters in square brackets in both the generated printer and parser. The printer emits '[' before and ']' after the array value; the parser calls parseLSquare()/parseRSquare() around the FieldParser call. Parameters with a custom printer or parser are unaffected (the user controls the format in that case). Fixes #156623 Assisted-by: Claude Code
This commit is contained in:
parent
a996f2a8db
commit
509f181f40
@ -27,5 +27,5 @@ module {
|
||||
// CHECK: #[[MOD:.+]] = #llvm.di_module<{{.*}}name = "foo"{{.*}}>
|
||||
// CHECK: #[[SP_REC:.+]] = #llvm.di_subprogram<recId = distinct[[[REC_ID:[0-9]+]]]<>, isRecSelf = true{{.*}}>
|
||||
// CHECK: #[[IMP_ENTITY:.+]] = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #[[SP_REC]], entity = #[[MOD]]{{.*}}>
|
||||
// CHECK: #[[SP:.+]] = #llvm.di_subprogram<recId = distinct[[[REC_ID]]]<>{{.*}}retainedNodes = #[[IMP_ENTITY]]>
|
||||
// CHECK: #[[SP:.+]] = #llvm.di_subprogram<recId = distinct[[[REC_ID]]]<>{{.*}}retainedNodes = [#[[IMP_ENTITY]]]>
|
||||
// CHECK: #llvm.di_global_variable<scope = #[[SP]], name = "xyz"{{.*}}>
|
||||
|
||||
@ -31,5 +31,5 @@ module {
|
||||
#di_module1 = #llvm.di_module<file = #di_file, scope = #di_compile_unit2, name = "mod2">
|
||||
#di_imported_entity = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #di_subprogram, entity = #di_module, file = #di_file, line = 1>
|
||||
#di_imported_entity1 = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #di_subprogram, entity = #di_module1, file = #di_file, line = 1>
|
||||
#di_subprogram1 = #llvm.di_subprogram<recId = distinct[2]<>, id = distinct[6]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, retainedNodes = #di_imported_entity, #di_imported_entity1>
|
||||
#di_subprogram1 = #llvm.di_subprogram<recId = distinct[2]<>, id = distinct[6]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, retainedNodes = [#di_imported_entity, #di_imported_entity1]>
|
||||
#loc8 = loc(fused<#di_subprogram1>[#loc1])
|
||||
|
||||
@ -819,7 +819,7 @@ define void @imp_fn() !dbg !12 {
|
||||
; CHECK-DAG: #[[M:.+]] = #llvm.di_module<{{.*}}name = "mod1"{{.*}}>
|
||||
; CHECK-DAG: #[[SP_REC:.+]] = #llvm.di_subprogram<recId = distinct{{.*}}<>, isRecSelf = true>
|
||||
; CHECK-DAG: #[[IE:.+]] = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #[[SP_REC]], entity = #[[M]]{{.*}}>
|
||||
; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<{{.*}}name = "imp_fn"{{.*}}retainedNodes = #[[IE]]>
|
||||
; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<{{.*}}name = "imp_fn"{{.*}}retainedNodes = [#[[IE]]]>
|
||||
|
||||
; // -----
|
||||
|
||||
|
||||
@ -394,7 +394,7 @@ llvm.func @imp_fn() {
|
||||
#di_subprogram = #llvm.di_subprogram<id = distinct[2]<>, recId = distinct[1]<>,
|
||||
compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn",
|
||||
file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type,
|
||||
retainedNodes = #di_imported_entity_1, #di_imported_entity_2>
|
||||
retainedNodes = [#di_imported_entity_1, #di_imported_entity_2]>
|
||||
#loc1 = loc("test.f90":12:14)
|
||||
#loc2 = loc(fused<#di_subprogram>[#loc1])
|
||||
|
||||
|
||||
@ -248,6 +248,39 @@ def TestAttrParams: Test_Attr<"TestAttrParams"> {
|
||||
let assemblyFormat = "`<` params `>`";
|
||||
}
|
||||
|
||||
// Test roundtrip of ArrayRefParameter in struct format (issue #156623).
|
||||
// An ArrayRefParameter without a custom printer in a non-last struct position
|
||||
// must be wrapped in `[...]` to be parseable.
|
||||
def TestAttrArrayRefStruct : Test_Attr<"TestAttrArrayRefStruct"> {
|
||||
let parameters = (ins
|
||||
ArrayRefParameter<"int64_t">:$elements,
|
||||
"int64_t":$count
|
||||
);
|
||||
let mnemonic = "arr_struct";
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
// Test that ArrayRefParameter in the LAST struct position is NOT wrapped:
|
||||
// no ambiguity exists because there is no following comma separator.
|
||||
def TestAttrArrayRefStructLast : Test_Attr<"TestAttrArrayRefStructLast"> {
|
||||
let parameters = (ins
|
||||
"int64_t":$count,
|
||||
ArrayRefParameter<"int64_t">:$elements
|
||||
);
|
||||
let mnemonic = "arr_struct_last";
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
// Test roundtrip of OptionalArrayRefParameter in non-last struct position.
|
||||
def TestAttrOptArrayRefStruct : Test_Attr<"TestAttrOptArrayRefStruct"> {
|
||||
let parameters = (ins
|
||||
OptionalArrayRefParameter<"int64_t">:$elements,
|
||||
"int64_t":$count
|
||||
);
|
||||
let mnemonic = "opt_arr_struct";
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
// Test types can be parsed/printed.
|
||||
def TestAttrWithTypeParam : Test_Attr<"TestAttrWithTypeParam"> {
|
||||
let parameters = (ins "::mlir::IntegerType":$int_type,
|
||||
|
||||
@ -36,7 +36,24 @@ attributes {
|
||||
// CHECK: #test<simple_enum"+">
|
||||
attr_14 = #test<simple_enum "+">,
|
||||
// CHECK: #test<simple_enum"dash-separated-sentence">
|
||||
attr_15 = #test<simple_enum "dash-separated-sentence">
|
||||
attr_15 = #test<simple_enum "dash-separated-sentence">,
|
||||
// Test that ArrayRefParameter in non-last struct position is wrapped in
|
||||
// brackets to avoid ambiguity with the struct-level comma (issue #156623).
|
||||
// CHECK: #test.arr_struct<elements = [1, 2, 3], count = 42>
|
||||
attr_arr_struct = #test.arr_struct<count = 42, elements = [1, 2, 3]>,
|
||||
// ArrayRefParameter in LAST declared/printed position must NOT be wrapped.
|
||||
// Input must provide count first so elements is parsed last (no ambiguity).
|
||||
// CHECK: #test.arr_struct_last<count = 5, elements = 1, 2, 3>
|
||||
attr_arr_struct_last = #test.arr_struct_last<count = 5, elements = 1, 2, 3>,
|
||||
// Single-element array in non-last position is still wrapped.
|
||||
// CHECK: #test.arr_struct<elements = [7], count = 1>
|
||||
attr_arr_struct_single = #test.arr_struct<count = 1, elements = [7]>,
|
||||
// OptionalArrayRefParameter in non-last struct position (present).
|
||||
// CHECK: #test.opt_arr_struct<elements = [4, 5], count = 9>
|
||||
attr_opt_arr_struct = #test.opt_arr_struct<count = 9, elements = [4, 5]>,
|
||||
// OptionalArrayRefParameter absent: no key emitted.
|
||||
// CHECK: #test.opt_arr_struct<count = 3>
|
||||
attr_opt_arr_struct_absent = #test.opt_arr_struct<count = 3>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @test_roundtrip_default_parsers_struct
|
||||
|
||||
@ -827,6 +827,74 @@ def AttrT : TestAttr<"TestT"> {
|
||||
let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
|
||||
}
|
||||
|
||||
/// Test that an `ArrayRefParameter` in a non-last position inside a `struct`
|
||||
/// directive is wrapped in `[...]` brackets to avoid ambiguity with the
|
||||
/// struct-level comma separator (issue #156623).
|
||||
def AttrWithArrayRefInStruct : TestAttr<"AttrWithArrayRefInStruct"> {
|
||||
let parameters = (ins
|
||||
ArrayRefParameter<"int64_t">:$elements,
|
||||
"int64_t":$count
|
||||
);
|
||||
let mnemonic = "arr_struct";
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
// ATTR: if (!_seen_elements && _paramKey == "elements") {
|
||||
// ATTR: // Parse literal '['
|
||||
// ATTR: if (odsParser.parseLSquare()) return {};
|
||||
// ATTR: _result_elements = ::mlir::FieldParser<::llvm::SmallVector<int64_t>>::parse(odsParser);
|
||||
// ATTR: // Parse literal ']'
|
||||
// ATTR: if (odsParser.parseRSquare()) return {};
|
||||
// ATTR: } else if (!_seen_count && _paramKey == "count") {
|
||||
// ATTR: odsPrinter << "elements = ";
|
||||
// ATTR: odsPrinter << "[";
|
||||
// ATTR: odsPrinter.printStrippedAttrOrType(getElements());
|
||||
// ATTR: odsPrinter << "]";
|
||||
// ATTR: odsPrinter << "count = ";
|
||||
|
||||
/// Test that an `ArrayRefParameter` in the LAST struct position is NOT wrapped
|
||||
/// in brackets (no ambiguity with a following comma).
|
||||
def AttrWithArrayRefInStructLast : TestAttr<"AttrWithArrayRefInStructLast"> {
|
||||
let parameters = (ins
|
||||
"int64_t":$count,
|
||||
ArrayRefParameter<"int64_t">:$elements
|
||||
);
|
||||
let mnemonic = "arr_struct_last";
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
// ATTR: if (!_seen_count && _paramKey == "count") {
|
||||
// ATTR-NOT: parseLSquare
|
||||
// ATTR: if (!_seen_elements && _paramKey == "elements") {
|
||||
// ATTR-NOT: parseLSquare
|
||||
// ATTR: odsPrinter << "count = ";
|
||||
// ATTR-NOT: odsPrinter << "[";
|
||||
// ATTR: odsPrinter << "elements = ";
|
||||
// ATTR-NOT: odsPrinter << "[";
|
||||
|
||||
/// Test that an `OptionalArrayRefParameter` in a non-last struct position is
|
||||
/// wrapped in `[...]` brackets (same rule as the non-optional variant).
|
||||
def AttrWithOptArrayRefInStruct : TestAttr<"AttrWithOptArrayRefInStruct"> {
|
||||
let parameters = (ins
|
||||
OptionalArrayRefParameter<"int64_t">:$elements,
|
||||
"int64_t":$count
|
||||
);
|
||||
let mnemonic = "opt_arr_struct";
|
||||
let assemblyFormat = "`<` struct(params) `>`";
|
||||
}
|
||||
|
||||
// ATTR: if (!_seen_elements && _paramKey == "elements") {
|
||||
// ATTR: // Parse literal '['
|
||||
// ATTR: if (odsParser.parseLSquare()) return {};
|
||||
// ATTR: _result_elements = ::mlir::FieldParser<::llvm::SmallVector<int64_t>>::parse(odsParser);
|
||||
// ATTR: // Parse literal ']'
|
||||
// ATTR: if (odsParser.parseRSquare()) return {};
|
||||
// ATTR: } else if (!_seen_count && _paramKey == "count") {
|
||||
// ATTR: odsPrinter << "elements = ";
|
||||
// ATTR: odsPrinter << "[";
|
||||
// ATTR: odsPrinter << "]";
|
||||
// ATTR: odsPrinter << "count = ";
|
||||
|
||||
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
|
||||
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
|
||||
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
|
||||
|
||||
@ -93,6 +93,22 @@ static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) {
|
||||
.DefaultUnreachable("unexpected struct element type");
|
||||
}
|
||||
|
||||
/// Returns true if the parameter is an `ArrayRefParameter` or
|
||||
/// `OptionalArrayRefParameter` without a custom printer or parser. Such
|
||||
/// parameters use a comma-separated list as their default format, which is
|
||||
/// ambiguous when used in a `struct` directive followed by other parameters.
|
||||
static bool isUndelimitedArrayRefParam(const ParameterElement *el) {
|
||||
// If the parameter has a custom printer or parser, the user controls the
|
||||
// format and printer/parser symmetry is their responsibility.
|
||||
if (el->getParam().getPrinter() || el->getParam().getParser())
|
||||
return false;
|
||||
const auto *defInit = dyn_cast<llvm::DefInit>(el->getParam().getDef());
|
||||
if (!defInit)
|
||||
return false;
|
||||
return defInit->getDef()->isSubClassOf("ArrayRefParameter") ||
|
||||
defInit->getDef()->isSubClassOf("OptionalArrayRefParameter");
|
||||
}
|
||||
|
||||
/// Shorthand functions that can be used with ranged-based conditions.
|
||||
static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
|
||||
static bool formatIsOptional(FormatElement *el) {
|
||||
@ -224,9 +240,10 @@ private:
|
||||
void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
|
||||
bool skipGuard = false);
|
||||
/// Generate a printer for comma-separated format elements.
|
||||
void genCommaSeparatedPrinter(ArrayRef<FormatElement *> params,
|
||||
FmtContext &ctx, MethodBody &os,
|
||||
function_ref<void(FormatElement *)> extra);
|
||||
void genCommaSeparatedPrinter(
|
||||
ArrayRef<FormatElement *> params, FmtContext &ctx, MethodBody &os,
|
||||
function_ref<void(FormatElement *)> extra,
|
||||
function_ref<void(FormatElement *)> extraPost = nullptr);
|
||||
/// Generate the printer code for a `params` directive.
|
||||
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
|
||||
/// Generate the printer code for a `struct` directive.
|
||||
@ -579,13 +596,27 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
|
||||
os.indent()
|
||||
<< "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
|
||||
genLiteralParser("=", ctx, os.indent());
|
||||
for (FormatElement *arg : el->getElements()) {
|
||||
ArrayRef<FormatElement *> structElems = el->getElements();
|
||||
for (auto [idx, arg] : llvm::enumerate(structElems)) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
os.getStream().printReindented(strfmt(checkParamKey, param->getName()));
|
||||
// An `ArrayRefParameter` without a custom parser in a non-last position
|
||||
// uses `[...]` delimiters to avoid ambiguity with the struct-level comma.
|
||||
bool useBrackets = isa<ParameterElement>(arg) &&
|
||||
isUndelimitedArrayRefParam(param) &&
|
||||
idx != structElems.size() - 1;
|
||||
if (useBrackets) {
|
||||
os.indent();
|
||||
genLiteralParser("[", ctx, os);
|
||||
}
|
||||
if (isa<ParameterElement>(arg))
|
||||
genVariableParser(param, ctx, os.indent());
|
||||
else if (auto *custom = dyn_cast<CustomDirective>(arg))
|
||||
genCustomParser(custom, ctx, os.indent());
|
||||
if (useBrackets) {
|
||||
os.unindent();
|
||||
genLiteralParser("]", ctx, os);
|
||||
}
|
||||
os.unindent() << "} else ";
|
||||
// Print the check for duplicate or unknown parameter.
|
||||
}
|
||||
@ -853,7 +884,8 @@ static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
|
||||
|
||||
void DefFormat::genCommaSeparatedPrinter(
|
||||
ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
|
||||
function_ref<void(FormatElement *)> extra) {
|
||||
function_ref<void(FormatElement *)> extra,
|
||||
function_ref<void(FormatElement *)> extraPost) {
|
||||
// Emit a space if necessary, but only if the struct is present.
|
||||
if (shouldEmitSpace || !lastWasPunctuation) {
|
||||
bool allOptional = llvm::all_of(args, formatIsOptional);
|
||||
@ -882,6 +914,8 @@ void DefFormat::genCommaSeparatedPrinter(
|
||||
genVariablePrinter(realParam, ctx, os);
|
||||
else if (auto *custom = dyn_cast<CustomDirective>(arg))
|
||||
genCustomPrinter(custom, ctx, os);
|
||||
if (extraPost)
|
||||
extraPost(arg);
|
||||
if (param->isOptional())
|
||||
os.unindent() << "}\n";
|
||||
}
|
||||
@ -899,10 +933,29 @@ void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
|
||||
|
||||
void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
|
||||
MethodBody &os) {
|
||||
genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
|
||||
});
|
||||
ArrayRef<FormatElement *> elems = el->getElements();
|
||||
// An `ArrayRefParameter` without a custom printer in a non-last struct
|
||||
// position must be wrapped in `[...]` to avoid ambiguity with the
|
||||
// struct-level comma separator. Track the element index via elemIdx, which is
|
||||
// incremented once per element in the extraPost callback.
|
||||
size_t elemIdx = 0;
|
||||
genCommaSeparatedPrinter(
|
||||
elems, ctx, os,
|
||||
[&](FormatElement *arg) {
|
||||
ParameterElement *param = getEncapsulatedParameterElement(arg);
|
||||
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
|
||||
auto *paramEl = dyn_cast<ParameterElement>(arg);
|
||||
if (paramEl && isUndelimitedArrayRefParam(paramEl) &&
|
||||
elemIdx + 1 < elems.size())
|
||||
os << tgfmt("$_printer << \"[\";\n", &ctx);
|
||||
},
|
||||
[&](FormatElement *arg) {
|
||||
auto *paramEl = dyn_cast<ParameterElement>(arg);
|
||||
if (paramEl && isUndelimitedArrayRefParam(paramEl) &&
|
||||
elemIdx + 1 < elems.size())
|
||||
os << tgfmt("$_printer << \"]\";\n", &ctx);
|
||||
++elemIdx;
|
||||
});
|
||||
}
|
||||
|
||||
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user