Two related bugs in generate-test-checks.py when a top-level operation
carries attribute alias references (e.g. `#map`, `#map1`) in its
signature:
1. The attribute reference substitution (replacing `#map` with
`#[[$ATTR_0]]`) ran *before* the pending attribute definitions were
processed, so the names were not yet available and the references were
left as-is in the output.
2. CHECK-LABEL lines do not support FileCheck variable references (e.g.
`#[[$ATTR_0]]`), so even after substitution the generated check would be
syntactically wrong.
Fix both issues:
- In the CHECK-LABEL branch, re-apply `process_attribute_references` to
the label prefix and SSA-split rest after flushing pending attribute
definitions, so that names are resolved.
- Split the label prefix at attribute reference boundaries; keep only
the text before the first reference in the CHECK-LABEL line and emit the
remainder on a CHECK-SAME line.
Before:
// CHECK-LABEL: func.func @test() attributes {amap = #map, bmap = #map1} {
After:
// CHECK-LABEL: func.func @test() attributes {amap =
// CHECK-SAME: #[[$ATTR_0]], bmap = #[[$ATTR_1]]} {
Fixes #162310
Assisted-by: Claude Code
533 lines
20 KiB
Python
Executable File
533 lines
20 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
"""A script to generate FileCheck statements for mlir unit tests.
|
||
|
||
This script is a utility to add FileCheck patterns to an mlir file.
|
||
|
||
NOTE: The input .mlir is expected to be the output from the parser, not a
|
||
stripped down variant.
|
||
|
||
Example usage:
|
||
$ generate-test-checks.py foo.mlir
|
||
$ mlir-opt foo.mlir -transformation | generate-test-checks.py
|
||
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir
|
||
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i
|
||
$ mlir-opt foo.mlir -transformation | generate-test-checks.py --source foo.mlir -i --source_delim_regex='gpu.func @'
|
||
|
||
The script will heuristically generate CHECK/CHECK-LABEL commands for each line
|
||
within the file. By default this script will also try to insert string
|
||
substitution blocks for all SSA value names. If --source file is specified, the
|
||
script will attempt to insert the generated CHECKs to the source file by looking
|
||
for line positions matched by --source_delim_regex.
|
||
|
||
The script is designed to make adding checks to a test case fast, it is *not*
|
||
designed to be authoritative about what constitutes a good test!
|
||
"""
|
||
|
||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
# See https://llvm.org/LICENSE.txt for license information.
|
||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
||
import argparse
|
||
import os # Used to advertise this file's name ("autogenerated_note").
|
||
import re
|
||
import sys
|
||
from collections import Counter
|
||
|
||
# Header written at the top of the generated output; main() appends the
# script's path and ADVERT_END to it.
ADVERT_BEGIN = "// NOTE: Assertions have been autogenerated by "
ADVERT_END = """
// This script is intended to make adding checks to a test case quick and easy.
// It is *not* authoritative about what constitutes a good test. After using the
// script, be sure to review and refine the generated checks. For example,
// CHECK lines should be minimized and named to reflect the test’s intent.
// For comprehensive guidelines, see:
// * https://mlir.llvm.org/getting_started/TestingGuide/
"""


# Regex command to match an SSA identifier (the text following a '%').
# NOTE: this pattern contains a top-level alternation, so it must be wrapped
# in a group whenever it is embedded in a larger regex.
SSA_RE_STR = "[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*"
SSA_RE = re.compile(SSA_RE_STR)

# Regex matching `dialect.op_name` (e.g. `vector.transfer_read`); group 1 is
# the op name without the dialect.
SSA_OP_NAME_RE = re.compile(r"\b(?:\s=\s[a-z_]+)[.]([a-z_]+)\b")

# Regex matching the left-hand side of an assignment, e.g. `%0, %res =`.
# The identifier pattern is wrapped in a non-capturing group so that its
# alternation stays scoped to the identifier; without the group the pattern
# degenerates to `%[0-9]+` OR `<name without '%'>` and named results such as
# `%arg0 = ...` are never recognised.
SSA_RESULTS_STR = (
    r"\s*(%(?:" + SSA_RE_STR + r"))(\s*,\s*(%(?:" + SSA_RE_STR + r")))*\s*="
)
SSA_RESULTS_RE = re.compile(SSA_RESULTS_STR)

# Regex matching attributes (e.g. `#map`, `#map1`).
ATTR_RE_STR = r"(#[a-zA-Z._-][a-zA-Z0-9._-]*)"
ATTR_RE = re.compile(ATTR_RE_STR)

# Regex matching the left-hand side of an attribute definition (`#map =`).
ATTR_DEF_RE_STR = r"\s*" + ATTR_RE_STR + r"\s*="
ATTR_DEF_RE = re.compile(ATTR_DEF_RE_STR)

# Regex matching a FileCheck attribute variable reference produced by this script,
# e.g. #[[$ATTR_0]] or #[[$ATTR_0:.+]]. Used to detect references that cannot
# appear in a CHECK-LABEL line (CHECK-LABEL does not support variable references).
ATTR_REF_IN_LABEL_RE = re.compile(r"(#\[\[\$[^\]]*\]\])")
|
||
|
||
|
||
class VariableNamer:
    """Generate and manage string substitution blocks for SSA value names.

    Names are stored per scope: `scopes` is a stack of dictionaries mapping an
    SSA value name to the FileCheck variable name generated for it.
    """

    def __init__(self, variable_names):
        """Create a namer.

        `variable_names` is a comma-separated string of user-chosen names,
        consumed in order by `generate_name`; an empty entry means "use a
        default name" for that position.
        """
        self.scopes = []
        # Counter for generic FileCheck names, e.g. VAL_#N
        self.name_counter = 0
        # Counters for FileCheck names derived from Op names, e.g.
        # TRANSFER_READ_#N (based on `vector.transfer_read`). Note, there's a
        # dedicated counter for every Op type present in the input.
        self.op_name_counter = Counter()

        # Number of variable names to still generate in parent scope
        self.generate_in_parent_scope_left = 0

        # Parse variable names
        self.variable_names = [name.upper() for name in variable_names.split(",")]
        self.used_variable_names = set()

    def generate_in_parent_scope(self, n):
        """Generate the following `n` variable names in the parent scope."""
        self.generate_in_parent_scope_left = n

    def generate_name(self, source_variable_name, use_ssa_name, op_name=""):
        """Generate a substitution name for the given SSA value name.

        The name is recorded in the current scope (or in the parent scope
        while `generate_in_parent_scope` names are still outstanding) and
        returned. Raises RuntimeError if the resulting name was already used.
        """
        # Compute variable name
        variable_name = self.variable_names.pop(0) if self.variable_names else ""
        if variable_name == "":
            # If `use_ssa_name` is set, use the MLIR SSA value name to generate
            # a FileCheck substitution string. As FileCheck requires these
            # strings to start with a character, skip MLIR variables starting
            # with a digit (e.g. `%0`).
            #
            # The next fallback option is to use the op name, if the
            # corresponding match succeeds.
            #
            # If neither worked, use a generic name: `VAL_#N`.
            #
            # Slice instead of indexing: the SSA name may be empty (a '%' that
            # is not followed by an identifier), which must fall through to the
            # other naming strategies rather than raise IndexError.
            if use_ssa_name and source_variable_name[:1].isalpha():
                variable_name = source_variable_name.upper()
            elif op_name != "":
                variable_name = op_name.upper() + "_" + str(self.op_name_counter[op_name])
                self.op_name_counter[op_name] += 1
            else:
                variable_name = "VAL_" + str(self.name_counter)
                self.name_counter += 1

        # Scope where variable name is saved
        scope = len(self.scopes) - 1
        if self.generate_in_parent_scope_left > 0:
            self.generate_in_parent_scope_left -= 1
            scope = len(self.scopes) - 2
        assert scope >= 0

        # Save variable
        if variable_name in self.used_variable_names:
            raise RuntimeError(variable_name + ": duplicate variable name")
        self.scopes[scope][source_variable_name] = variable_name
        self.used_variable_names.add(variable_name)

        return variable_name

    def push_name_scope(self):
        """Push a new variable name scope."""
        self.scopes.append({})

    def pop_name_scope(self):
        """Pop the last variable name scope."""
        self.scopes.pop()

    def num_scopes(self):
        """Return the level of nesting (number of pushed scopes)."""
        return len(self.scopes)

    def clear_names(self):
        """Reset the counters and the set of used variable names."""
        self.name_counter = 0
        self.used_variable_names = set()
        self.op_name_counter.clear()
|
||
|
||
class AttributeNamer:
    """Hand out FileCheck names for attribute aliases and remember them."""

    def __init__(self, attribute_names):
        """Create a namer from a comma-separated string of preferred names.

        The names are upper-cased and consumed in order by `generate_name`;
        an empty entry selects an automatically numbered name instead.
        """
        self.name_counter = 0
        self.attribute_names = list(map(str.upper, attribute_names.split(",")))
        self.map = {}
        self.used_attribute_names = set()

    def generate_name(self, source_attribute_name):
        """Generate, record and return a name for `source_attribute_name`.

        Raises RuntimeError when the generated name has already been used.
        """
        # Take the next user-supplied name, if any are left.
        requested = self.attribute_names.pop(0) if self.attribute_names else ""
        if requested:
            base_name = requested
        else:
            # No usable user-supplied name: fall back to ATTR_<N>.
            base_name = f"ATTR_{self.name_counter}"
            self.name_counter += 1

        # Prepend global symbol
        global_name = "$" + base_name

        if global_name in self.used_attribute_names:
            raise RuntimeError(f"{global_name}: duplicate attribute name")

        # Save attribute
        self.used_attribute_names.add(global_name)
        self.map[source_attribute_name] = global_name
        return global_name

    def get_name(self, source_attribute_name):
        """Return the saved name for the attribute, or None if there is none."""
        try:
            return self.map[source_attribute_name]
        except KeyError:
            return None
|
||
|
||
def get_num_ssa_results(input_line):
    """Return the number of SSA results in a line of type `%0, %1, ... = ...`.

    The count is the number of '%' characters in the leading part of the line
    matched by SSA_RESULTS_RE; 0 is returned when the line does not match.
    """
    lhs = SSA_RESULTS_RE.match(input_line)
    if lhs is None:
        return 0
    return lhs.group(0).count("%")
|
||
|
||
|
||
def process_line(line_chunks, variable_namer, use_ssa_name=False, strict_name_re=False):
    """Process a line of input that has been split at each SSA identifier '%'.

    Every chunk is expected to begin right after a '%'. The identifier at the
    start of the chunk is replaced by a FileCheck variable: a use
    (`%[[NAME]]`) when one of the namer's scopes already knows the identifier,
    otherwise a definition (`%[[NAME:<regex>]]`) with a freshly generated
    name. The processed chunks are joined, right-stripped and terminated with
    a newline.
    """
    pieces = []

    for chunk in line_chunks:
        # Identifier at the start of the chunk (empty if there is none).
        ssa_match = SSA_RE.match(chunk)
        ssa_name = "" if ssa_match is None else ssa_match.group(0)

        # Op name (without dialect) used as a naming fallback.
        op_match = SSA_OP_NAME_RE.search(chunk)
        op_name = "" if op_match is None else op_match.group(1)

        # Check if an existing variable exists for this name, looking at the
        # scopes in the order they were pushed.
        candidates = (scope.get(ssa_name) for scope in variable_namer.scopes)
        known = next((name for name in candidates if name is not None), None)

        if known is None:
            # Otherwise, generate a new variable.
            fresh = variable_namer.generate_name(ssa_name, use_ssa_name, op_name)
            # Use stricter regexp for the variable name, if requested.
            # Greedy matching may cause issues with the generic '.*'
            # regexp when the checks are split across several
            # lines (e.g. for CHECK-SAME).
            name_regex = SSA_RE_STR if strict_name_re else ".*"
            pieces.append("%[[" + fresh + ":" + name_regex + "]]")
        else:
            # If one exists, then output the existing name.
            pieces.append("%[[" + known + "]]")

        # Append the non named group.
        pieces.append(chunk[len(ssa_name) :])

    return "".join(pieces).rstrip() + "\n"
|
||
|
||
|
||
def process_source_lines(source_lines, args):
    """Split the source file lines into segments.

    The source file doesn't have to be .mlir. Lines containing
    `args.check_prefix` (previous CHECK lines) are dropped. A new segment is
    started at every line matched by `args.source_delim_regex`; that line
    becomes the first line of the new segment. Each kept line gets a trailing
    newline. The returned list always holds at least one (possibly empty)
    segment.
    """
    delimiter = re.compile(args.source_delim_regex)

    segments = []
    current = []
    for text in source_lines:
        # Remove previous CHECK lines.
        if args.check_prefix in text:
            continue
        # Segment the file based on --source_delim_regex.
        if delimiter.search(text):
            segments.append(current)
            current = []
        current.append(f"{text}\n")

    segments.append(current)
    return segments
|
||
|
||
|
||
def process_attribute_definition(line, attribute_namer, check_prefix="CHECK"):
    """Turn an attribute definition line into a FileCheck directive.

    If `line` starts with an attribute definition (`#name =`), a name is
    generated for the attribute through `attribute_namer` and a line of the
    form `// <check_prefix>: #[[<NAME>:.+]] =<rest>` is returned, terminated
    by a newline. Returns None when `line` is not an attribute definition.

    `check_prefix` selects the FileCheck directive prefix; it defaults to
    "CHECK", which is what this function always emitted before the parameter
    existed.
    """
    definition = ATTR_DEF_RE.match(line)
    if definition is None:
        return None

    # Register the name first so that it is already known while the rest of
    # the line is processed.
    attribute_name = attribute_namer.generate_name(definition.group(1))

    # The rest of the line may contain attribute references,
    # so we have to process them.
    rest = process_attribute_references(line[definition.end() :], attribute_namer)

    return f"// {check_prefix}: #[[{attribute_name}:.+]] ={rest}\n"
|
||
|
||
def process_attribute_references(line, attribute_namer):
    """Replace known attribute references in `line` with FileCheck variables.

    Every match of ATTR_RE for which `attribute_namer` has a saved name is
    rewritten to `#[[<NAME>]]`; attributes without a saved name and all other
    text are kept unchanged.
    """

    def substitute(found):
        known_name = attribute_namer.get_name(found.group(1))
        if known_name:
            return f"#[[{known_name}]]"
        return found.group(0)

    return ATTR_RE.sub(substitute, line)
|
||
|
||
def preprocess_line(line):
    """Escape character sequences that will be problematic with FileCheck.

    The replacements are applied one after another, in the order listed.
    """
    escapes = (
        # `{{` corresponds to regex checks in FileCheck.
        ("{{", "{{\\{\\{}}"),
        # '[[' corresponds to variable names in FileCheck.
        ("[[", "{{\\[\\[}}"),
        # A single bracket followed by an SSA identifier: the identifier will
        # be replaced by a variable, creating the same situation as above.
        ("[%", "{{\\[}}%"),
    )

    escaped = line
    for needle, replacement in escapes:
        escaped = escaped.replace(needle, replacement)
    return escaped
|
||
|
||
|
||
def _parse_bool_flag(value):
    """Convert a command-line string to a bool.

    `type=bool` treats every non-empty string (including "False") as True.
    Here the usual spellings of "false" are recognised as False; any other
    value keeps evaluating to True, as it did before.
    """
    return value.strip().lower() not in ("", "0", "false", "no", "off")


def main():
    """Parse the command line, generate the CHECK lines and write them out."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "--check-prefix", default="CHECK", help="Prefix to use from check file."
    )
    parser.add_argument(
        "-o", "--output", nargs="?", type=argparse.FileType("w"), default=None
    )
    parser.add_argument(
        "input", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument(
        "--source",
        type=str,
        help="Print each CHECK chunk before each delimiter line in the source "
        "file, respectively. The delimiter lines are identified by "
        "--source_delim_regex.",
    )
    parser.add_argument("--source_delim_regex", type=str, default="func @")
    parser.add_argument(
        "--starts_from_scope",
        type=int,
        default=1,
        help="Omit the top specified level of content. For example, by default "
        'it omits "module {"',
    )
    parser.add_argument("-i", "--inplace", action="store_true", default=False)
    parser.add_argument(
        "--variable_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent SSA "
        "variables in the order they are encountered. Separate names with commas, "
        "and leave empty entries for default names (e.g.: 'DIM,,SUM,RESULT')",
    )
    parser.add_argument(
        "--attribute_names",
        type=str,
        default="",
        help="Names to be used in FileCheck regular expression to represent "
        "attributes in the order they are defined. Separate names with "
        "commas, and leave empty entries for default names (e.g.: 'MAP0,,,MAP1')",
    )
    parser.add_argument(
        "--strict_name_re",
        # Not `type=bool`: bool("False") is True.
        type=_parse_bool_flag,
        default=False,
        help="Set to true to use stricter regex for CHECK-SAME directives. "
        "Use when Greedy matching causes issues with the generic '.*'",
    )

    args = parser.parse_args()

    # Reject inconsistent options up front, before any file is touched.
    if args.inplace:
        if args.output is not None:
            parser.error("--inplace cannot be combined with --output")
        if not args.source:
            parser.error("--inplace requires --source")

    # Open the given input file.
    input_lines = [line.rstrip() for line in args.input]
    args.input.close()

    # Generate a note used for the generated check file.
    script_name = os.path.basename(__file__)
    autogenerated_note = ADVERT_BEGIN + "utils/" + script_name + "\n" + ADVERT_END

    source_segments = None
    if args.source:
        with open(args.source, "r") as f:
            raw_source = f.read().replace(autogenerated_note, "")
        raw_source_lines = [line.rstrip() for line in raw_source.splitlines()]
        source_segments = process_source_lines(raw_source_lines, args)

    output_segments = [[]]

    # Namers
    variable_namer = VariableNamer(args.variable_names)
    attribute_namer = AttributeNamer(args.attribute_names)

    # Store attribute definitions to emit at appropriate scope
    pending_attr_defs = []

    # Process lines
    for input_line in input_lines:
        if not input_line:
            continue

        # When using `--starts_from_scope=0` to capture module lines, the file
        # split needs to be skipped, otherwise a `CHECK: // -----` is inserted.
        if input_line.startswith("// -----"):
            continue

        if ATTR_DEF_RE.match(input_line):
            pending_attr_defs.append(input_line)
            continue

        # Lines with blocks begin with a ^. These lines have a trailing comment
        # that needs to be stripped.
        lstripped_input_line = input_line.lstrip()
        is_block = lstripped_input_line[0] == "^"
        if is_block:
            input_line = input_line.rsplit("//", 1)[0].rstrip()

        cur_level = variable_namer.num_scopes()

        # If the line starts with a '}', pop the last name scope.
        if lstripped_input_line[0] == "}":
            variable_namer.pop_name_scope()
            cur_level = variable_namer.num_scopes()

        # If the line ends with a '{', push a new name scope.
        if input_line[-1] == "{":
            variable_namer.push_name_scope()
            if cur_level == args.starts_from_scope:
                output_segments.append([])

            # Result SSA values must still be pushed to parent scope
            num_ssa_results = get_num_ssa_results(input_line)
            variable_namer.generate_in_parent_scope(num_ssa_results)

        # Omit lines at the near top level e.g. "module {".
        if cur_level < args.starts_from_scope:
            continue

        if len(output_segments[-1]) == 0:
            variable_namer.clear_names()

        # Preprocess the input to remove any sequences that may be problematic with
        # FileCheck.
        input_line = preprocess_line(input_line)

        # Process uses of attributes in this line
        input_line = process_attribute_references(input_line, attribute_namer)

        # Split the line at the each SSA value name.
        ssa_split = input_line.split("%")

        # If this is a top-level operation use 'CHECK-LABEL', otherwise 'CHECK:'.
        if len(output_segments[-1]) != 0 or not ssa_split[0]:
            output_line = "// " + args.check_prefix + ": "
            # Pad to align with the 'LABEL' statements.
            output_line += " " * len("-LABEL")

            # Output the first line chunk that does not contain an SSA name.
            output_line += ssa_split[0]

            # Process the rest of the input line.
            output_line += process_line(ssa_split[1:], variable_namer)

        else:
            # Emit any pending attribute definitions at the start of this scope.
            # This must happen *before* re-processing the label line's attribute
            # references below, so that names are available for substitution.
            default_directive = "// CHECK:"
            for attr in pending_attr_defs:
                attr_line = process_attribute_definition(attr, attribute_namer)
                if attr_line:
                    # The definition line is produced with the default "CHECK"
                    # prefix; rewrite it when a custom --check-prefix is in use,
                    # otherwise FileCheck would never see the definition.
                    if args.check_prefix != "CHECK" and attr_line.startswith(
                        default_directive
                    ):
                        attr_line = (
                            "// "
                            + args.check_prefix
                            + ":"
                            + attr_line[len(default_directive) :]
                        )
                    output_segments[-1].append(attr_line)
            pending_attr_defs.clear()

            # Re-apply attribute reference substitution now that names have been
            # generated by the pending attribute definitions above. The earlier
            # substitution on the whole line may have run before the names were
            # defined.
            label_prefix = process_attribute_references(ssa_split[0], attribute_namer)
            ssa_rest = [
                process_attribute_references(arg, attribute_namer)
                for arg in ssa_split[1:]
            ]

            # CHECK-LABEL does not support FileCheck variable references such as
            # #[[$ATTR_0]]. If the label prefix contains attribute references, split
            # at the first one: keep only the text before it in the CHECK-LABEL line
            # and move the remainder to a following CHECK-SAME line.
            label_attr_parts = ATTR_REF_IN_LABEL_RE.split(label_prefix)

            output_line = (
                "// "
                + args.check_prefix
                + "-LABEL: "
                + label_attr_parts[0].rstrip()
                + "\n"
            )

            # Pad continuation lines to align with the end of the label prefix;
            # labels longer than 20 chars get a fixed 4-space pad instead, to
            # avoid excessive indentation.
            label_length = len(label_attr_parts[0])
            pad_depth = label_length if label_length < 21 else 4

            # Emit any attribute references from the label prefix as CHECK-SAME.
            if len(label_attr_parts) > 1:
                output_line += "// " + args.check_prefix + "-SAME: "
                output_line += " " * pad_depth
                output_line += "".join(label_attr_parts[1:]).rstrip() + "\n"

            # Process the rest of the input line on separate check lines.
            for argument in ssa_rest:
                output_line += "// " + args.check_prefix + "-SAME: "
                output_line += " " * pad_depth

                # Process the rest of the line. Use the original SSA name to generate the LIT
                # variable names.
                use_ssa_names = True
                output_line += process_line(
                    [argument], variable_namer, use_ssa_names, args.strict_name_re
                )

        # Append the output line.
        output_segments[-1].append(output_line)

    # Validate before anything is written, so that a mismatch cannot leave a
    # half-written (or, with --inplace, truncated) file behind.
    if source_segments and len(output_segments) != len(source_segments):
        raise RuntimeError(
            "number of generated check segments ("
            + str(len(output_segments))
            + ") does not match the number of source segments ("
            + str(len(source_segments))
            + "); check --source_delim_regex"
        )

    # Only open the destination now: opening the source for writing truncates
    # it, which must not happen before processing has succeeded.
    if args.inplace:
        output = open(args.source, "w")
    elif args.output is None:
        output = sys.stdout
    else:
        output = args.output

    try:
        output.write(autogenerated_note + "\n")

        # Write the output.
        if source_segments:
            for check_segment, source_segment in zip(output_segments, source_segments):
                for line in check_segment:
                    output.write(line)
                for line in source_segment:
                    output.write(line)
        else:
            for segment in output_segments:
                output.write("\n")
                for output_line in segment:
                    output.write(output_line)
                output.write("\n")
    finally:
        output.close()


if __name__ == "__main__":
    main()
|