From 509f181f40f86926a0c264b41fab7777be4ff91e Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 27 Mar 2026 19:41:46 +0100 Subject: [PATCH] [MLIR][TableGen] Fix ArrayRefParameter in struct format roundtrip (#189065) When an ArrayRefParameter (or OptionalArrayRefParameter) appears in a non-last position within a struct() assembly format directive, the printed output is ambiguous: the comma-separated array elements are indistinguishable from the struct-level commas separating key-value pairs. Fix this by wrapping such parameters in square brackets in both the generated printer and parser. The printer emits '[' before and ']' after the array value; the parser calls parseLSquare()/parseRSquare() around the FieldParser call. Parameters with a custom printer or parser are unaffected (the user controls the format in that case). Fixes #156623 Assisted-by: Claude Code --- .../test/Transforms/debug-imported-entity.fir | 2 +- mlir/test/Dialect/LLVMIR/bytecode.mlir | 2 +- mlir/test/Target/LLVMIR/Import/debug-info.ll | 2 +- mlir/test/Target/LLVMIR/llvmir-debug.mlir | 2 +- mlir/test/lib/Dialect/Test/TestAttrDefs.td | 33 +++++++++ .../attr-or-type-format-roundtrip.mlir | 19 ++++- mlir/test/mlir-tblgen/attr-or-type-format.td | 68 ++++++++++++++++++ .../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 71 ++++++++++++++++--- 8 files changed, 185 insertions(+), 14 deletions(-) diff --git a/flang/test/Transforms/debug-imported-entity.fir b/flang/test/Transforms/debug-imported-entity.fir index 246bdd8bf4e9..b49047491d75 100644 --- a/flang/test/Transforms/debug-imported-entity.fir +++ b/flang/test/Transforms/debug-imported-entity.fir @@ -27,5 +27,5 @@ module { // CHECK: #[[MOD:.+]] = #llvm.di_module<{{.*}}name = "foo"{{.*}}> // CHECK: #[[SP_REC:.+]] = #llvm.di_subprogram, isRecSelf = true{{.*}}> // CHECK: #[[IMP_ENTITY:.+]] = #llvm.di_imported_entity -// CHECK: #[[SP:.+]] = #llvm.di_subprogram{{.*}}retainedNodes = #[[IMP_ENTITY]]> +// CHECK: #[[SP:.+]] = #llvm.di_subprogram{{.*}}retainedNodes = [#[[IMP_ENTITY]]]> // CHECK: #llvm.di_global_variable diff --git a/mlir/test/Dialect/LLVMIR/bytecode.mlir b/mlir/test/Dialect/LLVMIR/bytecode.mlir index 821b0ac2196a..b70ded784bc4 100644 --- a/mlir/test/Dialect/LLVMIR/bytecode.mlir +++ b/mlir/test/Dialect/LLVMIR/bytecode.mlir @@ -31,5 +31,5 @@ module { #di_module1 = #llvm.di_module #di_imported_entity = #llvm.di_imported_entity #di_imported_entity1 = #llvm.di_imported_entity -#di_subprogram1 = #llvm.di_subprogram, id = distinct[6]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, retainedNodes = #di_imported_entity, #di_imported_entity1> +#di_subprogram1 = #llvm.di_subprogram, id = distinct[6]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, retainedNodes = [#di_imported_entity, #di_imported_entity1]> #loc8 = loc(fused<#di_subprogram1>[#loc1]) diff --git a/mlir/test/Target/LLVMIR/Import/debug-info.ll b/mlir/test/Target/LLVMIR/Import/debug-info.ll index 35ab877b8258..bc65188faa48 100644 --- a/mlir/test/Target/LLVMIR/Import/debug-info.ll +++ b/mlir/test/Target/LLVMIR/Import/debug-info.ll @@ -819,7 +819,7 @@ define void @imp_fn() !dbg !12 { ; CHECK-DAG: #[[M:.+]] = #llvm.di_module<{{.*}}name = "mod1"{{.*}}> ; CHECK-DAG: #[[SP_REC:.+]] = #llvm.di_subprogram, isRecSelf = true> ; CHECK-DAG: #[[IE:.+]] = #llvm.di_imported_entity -; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<{{.*}}name = "imp_fn"{{.*}}retainedNodes = #[[IE]]> +; CHECK-DAG: #[[SP:.+]] = #llvm.di_subprogram<{{.*}}name = "imp_fn"{{.*}}retainedNodes = [#[[IE]]]> ; // ----- diff --git a/mlir/test/Target/LLVMIR/llvmir-debug.mlir b/mlir/test/Target/LLVMIR/llvmir-debug.mlir index 0ff2bd0fc497..95895e3e56a0 100644 --- a/mlir/test/Target/LLVMIR/llvmir-debug.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-debug.mlir @@ -394,7 +394,7 @@ llvm.func @imp_fn() { #di_subprogram = #llvm.di_subprogram, recId = distinct[1]<>, compileUnit = #di_compile_unit, scope = #di_file, name = "imp_fn", file = #di_file, subprogramFlags = Definition, type = #di_subroutine_type, - retainedNodes = #di_imported_entity_1, #di_imported_entity_2> + retainedNodes = [#di_imported_entity_1, #di_imported_entity_2]> #loc1 = loc("test.f90":12:14) #loc2 = loc(fused<#di_subprogram>[#loc1]) diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 4cf836425256..ea527c962abe 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -248,6 +248,39 @@ def TestAttrParams: Test_Attr<"TestAttrParams"> { let assemblyFormat = "`<` params `>`"; } +// Test roundtrip of ArrayRefParameter in struct format (issue #156623). +// An ArrayRefParameter without a custom printer in a non-last struct position +// must be wrapped in `[...]` to be parseable. +def TestAttrArrayRefStruct : Test_Attr<"TestAttrArrayRefStruct"> { + let parameters = (ins + ArrayRefParameter<"int64_t">:$elements, + "int64_t":$count + ); + let mnemonic = "arr_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +// Test that ArrayRefParameter in the LAST struct position is NOT wrapped: +// no ambiguity exists because there is no following comma separator. +def TestAttrArrayRefStructLast : Test_Attr<"TestAttrArrayRefStructLast"> { + let parameters = (ins + "int64_t":$count, + ArrayRefParameter<"int64_t">:$elements + ); + let mnemonic = "arr_struct_last"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +// Test roundtrip of OptionalArrayRefParameter in non-last struct position. +def TestAttrOptArrayRefStruct : Test_Attr<"TestAttrOptArrayRefStruct"> { + let parameters = (ins + OptionalArrayRefParameter<"int64_t">:$elements, + "int64_t":$count + ); + let mnemonic = "opt_arr_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + // Test types can be parsed/printed. def TestAttrWithTypeParam : Test_Attr<"TestAttrWithTypeParam"> { let parameters = (ins "::mlir::IntegerType":$int_type, diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir index 9bb333f88dab..e7545afa7e4a 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir +++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir @@ -36,7 +36,24 @@ attributes { // CHECK: #test attr_14 = #test, // CHECK: #test - attr_15 = #test + attr_15 = #test, + // Test that ArrayRefParameter in non-last struct position is wrapped in + // brackets to avoid ambiguity with the struct-level comma (issue #156623). + // CHECK: #test.arr_struct + attr_arr_struct = #test.arr_struct, + // ArrayRefParameter in LAST declared/printed position must NOT be wrapped. + // Input must provide count first so elements is parsed last (no ambiguity). + // CHECK: #test.arr_struct_last + attr_arr_struct_last = #test.arr_struct_last, + // Single-element array in non-last position is still wrapped. + // CHECK: #test.arr_struct + attr_arr_struct_single = #test.arr_struct, + // OptionalArrayRefParameter in non-last struct position (present). + // CHECK: #test.opt_arr_struct + attr_opt_arr_struct = #test.opt_arr_struct, + // OptionalArrayRefParameter absent: no key emitted. + // CHECK: #test.opt_arr_struct + attr_opt_arr_struct_absent = #test.opt_arr_struct } // CHECK-LABEL: @test_roundtrip_default_parsers_struct diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td index d364debd4da9..3a464592d038 100644 --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -827,6 +827,74 @@ def AttrT : TestAttr<"TestT"> { let assemblyFormat = "`{` struct($v0, custom($v1), $v2) `}`"; } +/// Test that an `ArrayRefParameter` in a non-last position inside a `struct` +/// directive is wrapped in `[...]` brackets to avoid ambiguity with the +/// struct-level comma separator (issue #156623). +def AttrWithArrayRefInStruct : TestAttr<"AttrWithArrayRefInStruct"> { + let parameters = (ins + ArrayRefParameter<"int64_t">:$elements, + "int64_t":$count + ); + let mnemonic = "arr_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +// ATTR: if (!_seen_elements && _paramKey == "elements") { +// ATTR: // Parse literal '[' +// ATTR: if (odsParser.parseLSquare()) return {}; +// ATTR: _result_elements = ::mlir::FieldParser<::llvm::SmallVector>::parse(odsParser); +// ATTR: // Parse literal ']' +// ATTR: if (odsParser.parseRSquare()) return {}; +// ATTR: } else if (!_seen_count && _paramKey == "count") { +// ATTR: odsPrinter << "elements = "; +// ATTR: odsPrinter << "["; +// ATTR: odsPrinter.printStrippedAttrOrType(getElements()); +// ATTR: odsPrinter << "]"; +// ATTR: odsPrinter << "count = "; + +/// Test that an `ArrayRefParameter` in the LAST struct position is NOT wrapped +/// in brackets (no ambiguity with a following comma). +def AttrWithArrayRefInStructLast : TestAttr<"AttrWithArrayRefInStructLast"> { + let parameters = (ins + "int64_t":$count, + ArrayRefParameter<"int64_t">:$elements + ); + let mnemonic = "arr_struct_last"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +// ATTR: if (!_seen_count && _paramKey == "count") { +// ATTR-NOT: parseLSquare +// ATTR: if (!_seen_elements && _paramKey == "elements") { +// ATTR-NOT: parseLSquare +// ATTR: odsPrinter << "count = "; +// ATTR-NOT: odsPrinter << "["; +// ATTR: odsPrinter << "elements = "; +// ATTR-NOT: odsPrinter << "["; + +/// Test that an `OptionalArrayRefParameter` in a non-last struct position is +/// wrapped in `[...]` brackets (same rule as the non-optional variant). +def AttrWithOptArrayRefInStruct : TestAttr<"AttrWithOptArrayRefInStruct"> { + let parameters = (ins + OptionalArrayRefParameter<"int64_t">:$elements, + "int64_t":$count + ); + let mnemonic = "opt_arr_struct"; + let assemblyFormat = "`<` struct(params) `>`"; +} + +// ATTR: if (!_seen_elements && _paramKey == "elements") { +// ATTR: // Parse literal '[' +// ATTR: if (odsParser.parseLSquare()) return {}; +// ATTR: _result_elements = ::mlir::FieldParser<::llvm::SmallVector>::parse(odsParser); +// ATTR: // Parse literal ']' +// ATTR: if (odsParser.parseRSquare()) return {}; +// ATTR: } else if (!_seen_count && _paramKey == "count") { +// ATTR: odsPrinter << "elements = "; +// ATTR: odsPrinter << "["; +// ATTR: odsPrinter << "]"; +// ATTR: odsPrinter << "count = "; + // DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser) // DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType); // DEFAULT_TYPE_PARSER: if (parseResult.has_value()) { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index 6170e173c39e..a9bca471cf5b 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -93,6 +93,22 @@ static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) { .DefaultUnreachable("unexpected struct element type"); } +/// Returns true if the parameter is an `ArrayRefParameter` or +/// `OptionalArrayRefParameter` without a custom printer or parser. Such +/// parameters use a comma-separated list as their default format, which is +/// ambiguous when used in a `struct` directive followed by other parameters. +static bool isUndelimitedArrayRefParam(const ParameterElement *el) { + // If the parameter has a custom printer or parser, the user controls the + // format and printer/parser symmetry is their responsibility. + if (el->getParam().getPrinter() || el->getParam().getParser()) + return false; + const auto *defInit = dyn_cast(el->getParam().getDef()); + if (!defInit) + return false; + return defInit->getDef()->isSubClassOf("ArrayRefParameter") || + defInit->getDef()->isSubClassOf("OptionalArrayRefParameter"); +} + /// Shorthand functions that can be used with ranged-based conditions. static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); } static bool formatIsOptional(FormatElement *el) { @@ -224,9 +240,10 @@ private: void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, bool skipGuard = false); /// Generate a printer for comma-separated format elements. - void genCommaSeparatedPrinter(ArrayRef params, - FmtContext &ctx, MethodBody &os, - function_ref extra); + void genCommaSeparatedPrinter( + ArrayRef params, FmtContext &ctx, MethodBody &os, + function_ref extra, + function_ref extraPost = nullptr); /// Generate the printer code for a `params` directive. void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. @@ -579,13 +596,27 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, os.indent() << "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n"; genLiteralParser("=", ctx, os.indent()); - for (FormatElement *arg : el->getElements()) { + ArrayRef structElems = el->getElements(); + for (auto [idx, arg] : llvm::enumerate(structElems)) { ParameterElement *param = getEncapsulatedParameterElement(arg); os.getStream().printReindented(strfmt(checkParamKey, param->getName())); + // An `ArrayRefParameter` without a custom parser in a non-last position + // uses `[...]` delimiters to avoid ambiguity with the struct-level comma. + bool useBrackets = isa(arg) && + isUndelimitedArrayRefParam(param) && + idx != structElems.size() - 1; + if (useBrackets) { + os.indent(); + genLiteralParser("[", ctx, os); + } if (isa(arg)) genVariableParser(param, ctx, os.indent()); else if (auto *custom = dyn_cast(arg)) genCustomParser(custom, ctx, os.indent()); + if (useBrackets) { + os.unindent(); + genLiteralParser("]", ctx, os); + } os.unindent() << "} else "; // Print the check for duplicate or unknown parameter. } @@ -853,7 +884,8 @@ static void guardOnAnyOptional(FmtContext &ctx, MethodBody &os, void DefFormat::genCommaSeparatedPrinter( ArrayRef args, FmtContext &ctx, MethodBody &os, - function_ref extra) { + function_ref extra, + function_ref extraPost) { // Emit a space if necessary, but only if the struct is present. if (shouldEmitSpace || !lastWasPunctuation) { bool allOptional = llvm::all_of(args, formatIsOptional); @@ -882,6 +914,8 @@ void DefFormat::genCommaSeparatedPrinter( genVariablePrinter(realParam, ctx, os); else if (auto *custom = dyn_cast(arg)) genCustomPrinter(custom, ctx, os); + if (extraPost) + extraPost(arg); if (param->isOptional()) os.unindent() << "}\n"; } @@ -899,10 +933,29 @@ void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os) { - genCommaSeparatedPrinter(el->getElements(), ctx, os, [&](FormatElement *arg) { - ParameterElement *param = getEncapsulatedParameterElement(arg); - os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName()); - }); + ArrayRef elems = el->getElements(); + // An `ArrayRefParameter` without a custom printer in a non-last struct + // position must be wrapped in `[...]` to avoid ambiguity with the + // struct-level comma separator. Track the element index via elemIdx, which is + // incremented once per element in the extraPost callback. + size_t elemIdx = 0; + genCommaSeparatedPrinter( + elems, ctx, os, + [&](FormatElement *arg) { + ParameterElement *param = getEncapsulatedParameterElement(arg); + os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName()); + auto *paramEl = dyn_cast(arg); + if (paramEl && isUndelimitedArrayRefParam(paramEl) && + elemIdx + 1 < elems.size()) + os << tgfmt("$_printer << \"[\";\n", &ctx); + }, + [&](FormatElement *arg) { + auto *paramEl = dyn_cast(arg); + if (paramEl && isUndelimitedArrayRefParam(paramEl) && + elemIdx + 1 < elems.size()) + os << tgfmt("$_printer << \"]\";\n", &ctx); + ++elemIdx; + }); } void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,