[MLIR][TableGen] Fix ArrayRefParameter in struct format roundtrip (#189065)

When an ArrayRefParameter (or OptionalArrayRefParameter) appears in a
non-last position within a struct() assembly format directive, the printed
output is ambiguous: the comma-separated array elements are
indistinguishable from the struct-level commas separating key-value pairs.

Fix this by wrapping such parameters in square brackets in both the
generated printer and parser. The printer emits '[' before and ']' after
the array value; the parser calls parseLSquare()/parseRSquare() around the
FieldParser call. Parameters with a custom printer or parser are unaffected
(the user controls the format in that case).

Fixes #156623

Assisted-by: Claude Code
This commit is contained in:
Mehdi Amini 2026-03-27 19:41:46 +01:00 committed by GitHub
parent a996f2a8db
commit 509f181f40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 185 additions and 14 deletions

View File

@ -27,5 +27,5 @@ module {
// CHECK: #[[MOD:.+]] = #llvm.di_module<{{.*}}name = "foo"{{.*}}>
// CHECK: #[[SP_REC:.+]] = #llvm.di_subprogram<recId = distinct[[[REC_ID:[0-9]+]]]<>, isRecSelf = true{{.*}}>
// CHECK: #[[IMP_ENTITY:.+]] = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #[[SP_REC]], entity = #[[MOD]]{{.*}}>
// CHECK: #[[SP:.+]] = #llvm.di_subprogram<recId = distinct[[[REC_ID]]]<>{{.*}}retainedNodes = #[[IMP_ENTITY]]>
// CHECK: #[[SP:.+]] = #llvm.di_subprogram<recId = distinct[[[REC_ID]]]<>{{.*}}retainedNodes = [#[[IMP_ENTITY]]]>
// CHECK: #llvm.di_global_variable<scope = #[[SP]], name = "xyz"{{.*}}>

View File

@ -31,5 +31,5 @@ module {
#di_module1 = #llvm.di_module<file = #di_file, scope = #di_compile_unit2, name = "mod2">
#di_imported_entity = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #di_subprogram, entity = #di_module, file = #di_file, line = 1>
#di_imported_entity1 = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #di_subprogram, entity = #di_module1, file = #di_file, line = 1>
#di_subprogram1 = #llvm.di_subprogram<recId = distinct[2]<>, id = distinct[6]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, retainedNodes = #di_imported_entity, #di_imported_entity1>
#di_subprogram1 = #llvm.di_subprogram<recId = distinct[2]<>, id = distinct[6]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, retainedNodes = [#di_imported_entity, #di_imported_entity1]>
#loc8 = loc(fused<#di_subprogram1>[#loc1])

View File

@ -819,7 +819,7 @@ define void @imp_fn() !dbg !12 {
; CHECK-DAG: #[[M:.+]] = #llvm.di_module<{{.*}}name = "mod1"{{.*}}>
; CHECK-DAG: #[[SP_REC:.+]] = #llvm.di_subprogram<recId = distinct{{.*}}<>, isRecSelf = true>
; CHECK-DAG: #[[IE:.+]] = #llvm.di_imported_entity<tag = DW_TAG_imported_module, scope = #[[SP_REC]], entity = #[[M]]{{.*}}>
; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<{{.*}}name = "imp_fn"{{.*}}retainedNodes = #[[IE]]>
; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<{{.*}}name = "imp_fn"{{.*}}retainedNodes = [#[[IE]]]>
; // -----

View File

@ -394,7 +394,7 @@ llvm.func @imp_fn() {
#di_subprogram = #llvm.di_subprogram<id = distinct[2]<>, recId = distinct[1]<>,
compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn",
file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type,
retainedNodes = #di_imported_entity_1, #di_imported_entity_2>
retainedNodes = [#di_imported_entity_1, #di_imported_entity_2]>
#loc1 = loc("test.f90":12:14)
#loc2 = loc(fused<#di_subprogram>[#loc1])

View File

@ -248,6 +248,39 @@ def TestAttrParams: Test_Attr<"TestAttrParams"> {
let assemblyFormat = "`<` params `>`";
}
// Test roundtrip of ArrayRefParameter in struct format (issue #156623).
// An ArrayRefParameter without a custom printer in a non-last struct position
// must be wrapped in `[...]` to be parseable.
// `elements` is declared before `count`, so it is the non-last parameter here;
// the expected printed form is
// `#test.arr_struct<elements = [1, 2, 3], count = 42>`.
def TestAttrArrayRefStruct : Test_Attr<"TestAttrArrayRefStruct"> {
let parameters = (ins
ArrayRefParameter<"int64_t">:$elements,
"int64_t":$count
);
let mnemonic = "arr_struct";
let assemblyFormat = "`<` struct(params) `>`";
}
// Test that ArrayRefParameter in the LAST struct position is NOT wrapped:
// no ambiguity exists because there is no following comma separator.
// `elements` is declared after `count`, so it is the last parameter here; the
// expected printed form is
// `#test.arr_struct_last<count = 5, elements = 1, 2, 3>`.
def TestAttrArrayRefStructLast : Test_Attr<"TestAttrArrayRefStructLast"> {
let parameters = (ins
"int64_t":$count,
ArrayRefParameter<"int64_t">:$elements
);
let mnemonic = "arr_struct_last";
let assemblyFormat = "`<` struct(params) `>`";
}
// Test roundtrip of OptionalArrayRefParameter in non-last struct position.
// Expected printed forms:
//   present: `#test.opt_arr_struct<elements = [4, 5], count = 9>`
//   absent:  `#test.opt_arr_struct<count = 3>` (no `elements` key is emitted)
def TestAttrOptArrayRefStruct : Test_Attr<"TestAttrOptArrayRefStruct"> {
let parameters = (ins
OptionalArrayRefParameter<"int64_t">:$elements,
"int64_t":$count
);
let mnemonic = "opt_arr_struct";
let assemblyFormat = "`<` struct(params) `>`";
}
// Test types can be parsed/printed.
def TestAttrWithTypeParam : Test_Attr<"TestAttrWithTypeParam"> {
let parameters = (ins "::mlir::IntegerType":$int_type,

View File

@ -36,7 +36,24 @@ attributes {
// CHECK: #test<simple_enum"+">
attr_14 = #test<simple_enum "+">,
// CHECK: #test<simple_enum"dash-separated-sentence">
attr_15 = #test<simple_enum "dash-separated-sentence">
attr_15 = #test<simple_enum "dash-separated-sentence">,
// Test that ArrayRefParameter in non-last struct position is wrapped in
// brackets to avoid ambiguity with the struct-level comma (issue #156623).
// CHECK: #test.arr_struct<elements = [1, 2, 3], count = 42>
attr_arr_struct = #test.arr_struct<count = 42, elements = [1, 2, 3]>,
// ArrayRefParameter in LAST declared/printed position must NOT be wrapped.
// Input must provide count first so elements is parsed last (no ambiguity).
// CHECK: #test.arr_struct_last<count = 5, elements = 1, 2, 3>
attr_arr_struct_last = #test.arr_struct_last<count = 5, elements = 1, 2, 3>,
// Single-element array in non-last position is still wrapped.
// CHECK: #test.arr_struct<elements = [7], count = 1>
attr_arr_struct_single = #test.arr_struct<count = 1, elements = [7]>,
// OptionalArrayRefParameter in non-last struct position (present).
// CHECK: #test.opt_arr_struct<elements = [4, 5], count = 9>
attr_opt_arr_struct = #test.opt_arr_struct<count = 9, elements = [4, 5]>,
// OptionalArrayRefParameter absent: no key emitted.
// CHECK: #test.opt_arr_struct<count = 3>
attr_opt_arr_struct_absent = #test.opt_arr_struct<count = 3>
}
// CHECK-LABEL: @test_roundtrip_default_parsers_struct

View File

@ -827,6 +827,74 @@ def AttrT : TestAttr<"TestT"> {
let assemblyFormat = "`{` struct($v0, custom<NestedCustom>($v1), $v2) `}`";
}
/// Test that an `ArrayRefParameter` in a non-last position inside a `struct`
/// directive is wrapped in `[...]` brackets to avoid ambiguity with the
/// struct-level comma separator (issue #156623).
def AttrWithArrayRefInStruct : TestAttr<"AttrWithArrayRefInStruct"> {
let parameters = (ins
ArrayRefParameter<"int64_t">:$elements,
"int64_t":$count
);
let mnemonic = "arr_struct";
let assemblyFormat = "`<` struct(params) `>`";
}
// ATTR: if (!_seen_elements && _paramKey == "elements") {
// ATTR: // Parse literal '['
// ATTR: if (odsParser.parseLSquare()) return {};
// ATTR: _result_elements = ::mlir::FieldParser<::llvm::SmallVector<int64_t>>::parse(odsParser);
// ATTR: // Parse literal ']'
// ATTR: if (odsParser.parseRSquare()) return {};
// ATTR: } else if (!_seen_count && _paramKey == "count") {
// ATTR: odsPrinter << "elements = ";
// ATTR: odsPrinter << "[";
// ATTR: odsPrinter.printStrippedAttrOrType(getElements());
// ATTR: odsPrinter << "]";
// ATTR: odsPrinter << "count = ";
/// Test that an `ArrayRefParameter` in the LAST struct position is NOT wrapped
/// in brackets (no ambiguity with a following comma).
def AttrWithArrayRefInStructLast : TestAttr<"AttrWithArrayRefInStructLast"> {
let parameters = (ins
"int64_t":$count,
ArrayRefParameter<"int64_t">:$elements
);
let mnemonic = "arr_struct_last";
let assemblyFormat = "`<` struct(params) `>`";
}
// ATTR: if (!_seen_count && _paramKey == "count") {
// ATTR-NOT: parseLSquare
// ATTR: if (!_seen_elements && _paramKey == "elements") {
// ATTR-NOT: parseLSquare
// ATTR: odsPrinter << "count = ";
// ATTR-NOT: odsPrinter << "[";
// ATTR: odsPrinter << "elements = ";
// ATTR-NOT: odsPrinter << "[";
/// Test that an `OptionalArrayRefParameter` in a non-last struct position is
/// wrapped in `[...]` brackets (same rule as the non-optional variant).
def AttrWithOptArrayRefInStruct : TestAttr<"AttrWithOptArrayRefInStruct"> {
let parameters = (ins
OptionalArrayRefParameter<"int64_t">:$elements,
"int64_t":$count
);
let mnemonic = "opt_arr_struct";
let assemblyFormat = "`<` struct(params) `>`";
}
// ATTR: if (!_seen_elements && _paramKey == "elements") {
// ATTR: // Parse literal '['
// ATTR: if (odsParser.parseLSquare()) return {};
// ATTR: _result_elements = ::mlir::FieldParser<::llvm::SmallVector<int64_t>>::parse(odsParser);
// ATTR: // Parse literal ']'
// ATTR: if (odsParser.parseRSquare()) return {};
// ATTR: } else if (!_seen_count && _paramKey == "count") {
// ATTR: odsPrinter << "elements = ";
// ATTR: odsPrinter << "[";
// ATTR: odsPrinter << "]";
// ATTR: odsPrinter << "count = ";
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {

View File

@ -93,6 +93,22 @@ static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) {
.DefaultUnreachable("unexpected struct element type");
}
/// Reports whether `el` refers to an `ArrayRefParameter` or
/// `OptionalArrayRefParameter` that relies on the default (generated) printer
/// and parser. The default format for such parameters is a bare
/// comma-separated list, which cannot be told apart from the commas of a
/// `struct` directive when further parameters follow it.
static bool isUndelimitedArrayRefParam(const ParameterElement *el) {
  const auto &param = el->getParam();

  // A user-supplied printer or parser means the user owns the textual format
  // (and is responsible for keeping printing and parsing symmetric).
  if (param.getPrinter() || param.getParser())
    return false;

  // Only parameters backed by a TableGen record definition can derive from
  // one of the array-ref parameter classes.
  if (const auto *defInit = dyn_cast<llvm::DefInit>(param.getDef())) {
    const auto *record = defInit->getDef();
    return record->isSubClassOf("ArrayRefParameter") ||
           record->isSubClassOf("OptionalArrayRefParameter");
  }
  return false;
}
/// Shorthand functions that can be used with ranged-based conditions.
static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
static bool formatIsOptional(FormatElement *el) {
@ -224,9 +240,10 @@ private:
void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
bool skipGuard = false);
/// Generate a printer for comma-separated format elements.
void genCommaSeparatedPrinter(ArrayRef<FormatElement *> params,
FmtContext &ctx, MethodBody &os,
function_ref<void(FormatElement *)> extra);
void genCommaSeparatedPrinter(
ArrayRef<FormatElement *> params, FmtContext &ctx, MethodBody &os,
function_ref<void(FormatElement *)> extra,
function_ref<void(FormatElement *)> extraPost = nullptr);
/// Generate the printer code for a `params` directive.
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
@ -579,13 +596,27 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
os.indent()
<< "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
genLiteralParser("=", ctx, os.indent());
for (FormatElement *arg : el->getElements()) {
ArrayRef<FormatElement *> structElems = el->getElements();
for (auto [idx, arg] : llvm::enumerate(structElems)) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
os.getStream().printReindented(strfmt(checkParamKey, param->getName()));
// An `ArrayRefParameter` without a custom parser in a non-last position
// uses `[...]` delimiters to avoid ambiguity with the struct-level comma.
bool useBrackets = isa<ParameterElement>(arg) &&
isUndelimitedArrayRefParam(param) &&
idx != structElems.size() - 1;
if (useBrackets) {
os.indent();
genLiteralParser("[", ctx, os);
}
if (isa<ParameterElement>(arg))
genVariableParser(param, ctx, os.indent());
else if (auto *custom = dyn_cast<CustomDirective>(arg))
genCustomParser(custom, ctx, os.indent());
if (useBrackets) {
os.unindent();
genLiteralParser("]", ctx, os);
}
os.unindent() << "} else ";
// Print the check for duplicate or unknown parameter.
}
@ -853,7 +884,8 @@ static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os,
void DefFormat::genCommaSeparatedPrinter(
ArrayRef<FormatElement *> args, FmtContext &ctx, MethodBody &os,
function_ref<void(FormatElement *)> extra) {
function_ref<void(FormatElement *)> extra,
function_ref<void(FormatElement *)> extraPost) {
// Emit a space if necessary, but only if the struct is present.
if (shouldEmitSpace || !lastWasPunctuation) {
bool allOptional = llvm::all_of(args, formatIsOptional);
@ -882,6 +914,8 @@ void DefFormat::genCommaSeparatedPrinter(
genVariablePrinter(realParam, ctx, os);
else if (auto *custom = dyn_cast<CustomDirective>(arg))
genCustomPrinter(custom, ctx, os);
if (extraPost)
extraPost(arg);
if (param->isOptional())
os.unindent() << "}\n";
}
@ -899,10 +933,29 @@ void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
MethodBody &os) {
genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
});
ArrayRef<FormatElement *> elems = el->getElements();
// An `ArrayRefParameter` without a custom printer in a non-last struct
// position must be wrapped in `[...]` to avoid ambiguity with the
// struct-level comma separator. Track the element index via elemIdx, which is
// incremented once per element in the extraPost callback.
size_t elemIdx = 0;
genCommaSeparatedPrinter(
elems, ctx, os,
[&](FormatElement *arg) {
ParameterElement *param = getEncapsulatedParameterElement(arg);
os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
auto *paramEl = dyn_cast<ParameterElement>(arg);
if (paramEl && isUndelimitedArrayRefParam(paramEl) &&
elemIdx + 1 < elems.size())
os << tgfmt("$_printer << \"[\";\n", &ctx);
},
[&](FormatElement *arg) {
auto *paramEl = dyn_cast<ParameterElement>(arg);
if (paramEl && isUndelimitedArrayRefParam(paramEl) &&
elemIdx + 1 < elems.size())
os << tgfmt("$_printer << \"]\";\n", &ctx);
++elemIdx;
});
}
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,