[MLIR][Transform] apply_registered_pass: support ListOptions (#144026)
Interpret an option value with multiple values, either in the form of an `ArrayAttr` (either static or passed through a param) or as the multiple attrs associated to a param, as a comma-separated list, i.e. as a ListOption on a pass.
This commit is contained in:
parent
299a55a88f
commit
e00853859e
@ -418,11 +418,14 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
|
|||||||
with options = { "top-down" = false,
|
with options = { "top-down" = false,
|
||||||
"max-iterations" = %max_iter,
|
"max-iterations" = %max_iter,
|
||||||
"test-convergence" = true,
|
"test-convergence" = true,
|
||||||
"max-num-rewrites" = %max_rewrites }
|
"max-num-rewrites" = %max_rewrites }
|
||||||
to %module
|
to %module
|
||||||
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
|
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Options' values which are `ArrayAttr`s are converted to comma-separated
|
||||||
|
lists of options. Likewise for params which associate multiple values.
|
||||||
|
|
||||||
This op first looks for a pass pipeline with the specified name. If no such
|
This op first looks for a pass pipeline with the specified name. If no such
|
||||||
pipeline exists, it looks for a pass with the specified name. If no such
|
pipeline exists, it looks for a pass with the specified name. If no such
|
||||||
pass exists either, this op fails definitely.
|
pass exists either, this op fails definitely.
|
||||||
|
@ -788,46 +788,47 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
|
|||||||
// Obtain a single options-string to pass to the pass(-pipeline) from options
|
// Obtain a single options-string to pass to the pass(-pipeline) from options
|
||||||
// passed in as a dictionary of keys mapping to values which are either
|
// passed in as a dictionary of keys mapping to values which are either
|
||||||
// attributes or param-operands pointing to attributes.
|
// attributes or param-operands pointing to attributes.
|
||||||
|
OperandRange dynamicOptions = getDynamicOptions();
|
||||||
|
|
||||||
std::string options;
|
std::string options;
|
||||||
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
|
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
|
||||||
|
|
||||||
OperandRange dynamicOptions = getDynamicOptions();
|
// A helper to convert an option's attribute value into a corresponding
|
||||||
for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
|
// string representation, with the ability to obtain the attr(s) from a param.
|
||||||
if (idx > 0)
|
std::function<void(Attribute)> appendValueAttr = [&](Attribute valueAttr) {
|
||||||
optionsStream << " "; // Interleave options separator.
|
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
|
||||||
optionsStream << namedAttribute.getName().str(); // Append the key.
|
// The corresponding value attribute(s) is/are passed in via a param.
|
||||||
optionsStream << "="; // And the key-value separator.
|
|
||||||
|
|
||||||
Attribute valueAttrToAppend;
|
|
||||||
if (auto paramOperandIndex =
|
|
||||||
dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
|
|
||||||
// The corresponding value attribute is passed in via a param.
|
|
||||||
// Obtain the param-operand via its specified index.
|
// Obtain the param-operand via its specified index.
|
||||||
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
|
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
|
||||||
assert(dynamicOptionIdx < dynamicOptions.size() &&
|
assert(dynamicOptionIdx < dynamicOptions.size() &&
|
||||||
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
|
"the number of ParamOperandAttrs in the options DictionaryAttr"
|
||||||
"should be the same as the number of options passed as params");
|
"should be the same as the number of options passed as params");
|
||||||
ArrayRef<Attribute> dynamicOption =
|
ArrayRef<Attribute> attrsAssociatedToParam =
|
||||||
state.getParams(dynamicOptions[dynamicOptionIdx]);
|
state.getParams(dynamicOptions[dynamicOptionIdx]);
|
||||||
if (dynamicOption.size() != 1)
|
// Recursive so as to append all attrs associated to the param.
|
||||||
return emitSilenceableError()
|
llvm::interleave(attrsAssociatedToParam, optionsStream, appendValueAttr,
|
||||||
<< "options passed as a param must have "
|
",");
|
||||||
"a single value associated, param "
|
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
|
||||||
<< dynamicOptionIdx << " associates " << dynamicOption.size();
|
// Recursive so as to append all nested attrs of the array.
|
||||||
valueAttrToAppend = dynamicOption[0];
|
llvm::interleave(arrayAttr, optionsStream, appendValueAttr, ",");
|
||||||
} else {
|
} else if (auto strAttr = dyn_cast<StringAttr>(valueAttr)) {
|
||||||
// Value is a static attribute.
|
// Convert to unquoted string.
|
||||||
valueAttrToAppend = namedAttribute.getValue();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append string representation of value attribute.
|
|
||||||
if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
|
|
||||||
optionsStream << strAttr.getValue().str();
|
optionsStream << strAttr.getValue().str();
|
||||||
} else {
|
} else {
|
||||||
valueAttrToAppend.print(optionsStream, /*elideType=*/true);
|
// For all other attributes, ask the attr to print itself (without type).
|
||||||
|
valueAttr.print(optionsStream, /*elideType=*/true);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
|
// Convert the options DictionaryAttr into a single string.
|
||||||
|
llvm::interleave(
|
||||||
|
getOptions(), optionsStream,
|
||||||
|
[&](auto namedAttribute) {
|
||||||
|
optionsStream << namedAttribute.getName().str(); // Append the key.
|
||||||
|
optionsStream << "="; // And the key-value separator.
|
||||||
|
appendValueAttr(namedAttribute.getValue()); // And the attr's str repr.
|
||||||
|
},
|
||||||
|
" ");
|
||||||
optionsStream.flush();
|
optionsStream.flush();
|
||||||
|
|
||||||
// Get pass or pass pipeline from registry.
|
// Get pass or pass pipeline from registry.
|
||||||
@ -878,23 +879,30 @@ static ParseResult parseApplyRegisteredPassOptions(
|
|||||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
|
||||||
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
|
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
|
||||||
SmallVector<NamedAttribute> keyValuePairs;
|
SmallVector<NamedAttribute> keyValuePairs;
|
||||||
|
|
||||||
size_t dynamicOptionsIdx = 0;
|
size_t dynamicOptionsIdx = 0;
|
||||||
auto parseKeyValuePair = [&]() -> ParseResult {
|
|
||||||
// Parse items of the form `key = value` where `key` is a bare identifier or
|
|
||||||
// a string and `value` is either an attribute or an operand.
|
|
||||||
|
|
||||||
std::string key;
|
// Helper for allowing parsing of option values which can be of the form:
|
||||||
Attribute valueAttr;
|
// - a normal attribute
|
||||||
if (parser.parseOptionalKeywordOrString(&key))
|
// - an operand (which would be converted to an attr referring to the operand)
|
||||||
return parser.emitError(parser.getCurrentLocation())
|
// - ArrayAttrs containing the foregoing (in correspondence with ListOptions)
|
||||||
<< "expected key to either be an identifier or a string";
|
std::function<ParseResult(Attribute &)> parseValue =
|
||||||
if (key.empty())
|
[&](Attribute &valueAttr) -> ParseResult {
|
||||||
return failure();
|
// Allow for array syntax, e.g. `[0 : i64, %param, true, %other_param]`:
|
||||||
|
if (succeeded(parser.parseOptionalLSquare())) {
|
||||||
|
SmallVector<Attribute> attrs;
|
||||||
|
|
||||||
if (parser.parseEqual())
|
// Recursively parse the array's elements, which might be operands.
|
||||||
return parser.emitError(parser.getCurrentLocation())
|
if (parser.parseCommaSeparatedList(
|
||||||
<< "expected '=' after key in key-value pair";
|
AsmParser::Delimiter::None,
|
||||||
|
[&]() -> ParseResult { return parseValue(attrs.emplace_back()); },
|
||||||
|
" in options dictionary") ||
|
||||||
|
parser.parseRSquare())
|
||||||
|
return failure(); // NB: Attempted parse should've output error message.
|
||||||
|
|
||||||
|
valueAttr = ArrayAttr::get(parser.getContext(), attrs);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
// Parse the value, which can be either an attribute or an operand.
|
// Parse the value, which can be either an attribute or an operand.
|
||||||
OptionalParseResult parsedValueAttr =
|
OptionalParseResult parsedValueAttr =
|
||||||
@ -903,9 +911,7 @@ static ParseResult parseApplyRegisteredPassOptions(
|
|||||||
OpAsmParser::UnresolvedOperand operand;
|
OpAsmParser::UnresolvedOperand operand;
|
||||||
ParseResult parsedOperand = parser.parseOperand(operand);
|
ParseResult parsedOperand = parser.parseOperand(operand);
|
||||||
if (failed(parsedOperand))
|
if (failed(parsedOperand))
|
||||||
return parser.emitError(parser.getCurrentLocation())
|
return failure(); // NB: Attempted parse should've output error message.
|
||||||
<< "expected a valid attribute or operand as value associated "
|
|
||||||
<< "to key '" << key << "'";
|
|
||||||
// To make use of the operand, we need to store it in the options dict.
|
// To make use of the operand, we need to store it in the options dict.
|
||||||
// As SSA-values cannot occur in attributes, what we do instead is store
|
// As SSA-values cannot occur in attributes, what we do instead is store
|
||||||
// an attribute in its place that contains the index of the param-operand,
|
// an attribute in its place that contains the index of the param-operand,
|
||||||
@ -924,7 +930,30 @@ static ParseResult parseApplyRegisteredPassOptions(
|
|||||||
<< "in the generic print format";
|
<< "in the generic print format";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
// Helper for `key = value`-pair parsing where `key` is a bare identifier or a
|
||||||
|
// string and `value` looks like either an attribute or an operand-in-an-attr.
|
||||||
|
std::function<ParseResult()> parseKeyValuePair = [&]() -> ParseResult {
|
||||||
|
std::string key;
|
||||||
|
Attribute valueAttr;
|
||||||
|
|
||||||
|
if (failed(parser.parseOptionalKeywordOrString(&key)) || key.empty())
|
||||||
|
return parser.emitError(parser.getCurrentLocation())
|
||||||
|
<< "expected key to either be an identifier or a string";
|
||||||
|
|
||||||
|
if (failed(parser.parseEqual()))
|
||||||
|
return parser.emitError(parser.getCurrentLocation())
|
||||||
|
<< "expected '=' after key in key-value pair";
|
||||||
|
|
||||||
|
if (failed(parseValue(valueAttr)))
|
||||||
|
return parser.emitError(parser.getCurrentLocation())
|
||||||
|
<< "expected a valid attribute or operand as value associated "
|
||||||
|
<< "to key '" << key << "'";
|
||||||
|
|
||||||
keyValuePairs.push_back(NamedAttribute(key, valueAttr));
|
keyValuePairs.push_back(NamedAttribute(key, valueAttr));
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -951,16 +980,27 @@ static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
|
|||||||
if (options.empty())
|
if (options.empty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
std::function<void(Attribute)> printOptionValue = [&](Attribute valueAttr) {
|
||||||
|
if (auto paramOperandAttr =
|
||||||
|
dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
|
||||||
|
// Resolve index of param-operand to its actual SSA-value and print that.
|
||||||
|
printer.printOperand(
|
||||||
|
dynamicOptions[paramOperandAttr.getIndex().getInt()]);
|
||||||
|
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
|
||||||
|
// This case is so that ArrayAttr-contained operands are pretty-printed.
|
||||||
|
printer << "[";
|
||||||
|
llvm::interleaveComma(arrayAttr, printer, printOptionValue);
|
||||||
|
printer << "]";
|
||||||
|
} else {
|
||||||
|
printer.printAttribute(valueAttr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
printer << "{";
|
printer << "{";
|
||||||
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
|
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
|
||||||
printer << namedAttribute.getName() << " = ";
|
printer << namedAttribute.getName();
|
||||||
Attribute value = namedAttribute.getValue();
|
printer << " = ";
|
||||||
if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
|
printOptionValue(namedAttribute.getValue());
|
||||||
// Resolve index of param-operand to its actual SSA-value and print that.
|
|
||||||
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
|
|
||||||
} else {
|
|
||||||
printer.printAttribute(value);
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
printer << "}";
|
printer << "}";
|
||||||
}
|
}
|
||||||
@ -970,9 +1010,11 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
|
|||||||
// and references to dynamic options in the options dictionary.
|
// and references to dynamic options in the options dictionary.
|
||||||
|
|
||||||
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
|
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
|
||||||
for (NamedAttribute namedAttr : getOptions())
|
|
||||||
if (auto paramOperand =
|
// Helper for option values to mark seen operands as having been seen (once).
|
||||||
dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
|
std::function<LogicalResult(Attribute)> checkOptionValue =
|
||||||
|
[&](Attribute valueAttr) -> LogicalResult {
|
||||||
|
if (auto paramOperand = dyn_cast<transform::ParamOperandAttr>(valueAttr)) {
|
||||||
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
|
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
|
||||||
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
|
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
|
||||||
return emitOpError()
|
return emitOpError()
|
||||||
@ -983,8 +1025,20 @@ LogicalResult transform::ApplyRegisteredPassOp::verify() {
|
|||||||
return emitOpError() << "dynamic option index " << dynamicOptionIdx
|
return emitOpError() << "dynamic option index " << dynamicOptionIdx
|
||||||
<< " is already used in options";
|
<< " is already used in options";
|
||||||
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
|
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
|
||||||
|
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
|
||||||
|
// Recurse into ArrayAttrs as they may contain references to operands.
|
||||||
|
for (auto eltAttr : arrayAttr)
|
||||||
|
if (failed(checkOptionValue(eltAttr)))
|
||||||
|
return failure();
|
||||||
}
|
}
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
|
||||||
|
for (NamedAttribute namedAttr : getOptions())
|
||||||
|
if (failed(checkOptionValue(namedAttr.getValue())))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// All dynamicOptions-params seen in the dict will have been set to null.
|
||||||
for (Value dynamicOption : dynamicOptions)
|
for (Value dynamicOption : dynamicOptions)
|
||||||
if (dynamicOption)
|
if (dynamicOption)
|
||||||
return emitOpError() << "a param operand does not have a corresponding "
|
return emitOpError() << "a param operand does not have a corresponding "
|
||||||
|
@ -219,6 +219,11 @@ class YieldOp(YieldOp):
|
|||||||
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
|
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
|
||||||
|
|
||||||
|
|
||||||
|
OptionValueTypes = Union[
|
||||||
|
Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@_ods_cext.register_operation(_Dialect, replace=True)
|
@_ods_cext.register_operation(_Dialect, replace=True)
|
||||||
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
|
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -227,12 +232,7 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
|
|||||||
target: Union[Operation, Value, OpView],
|
target: Union[Operation, Value, OpView],
|
||||||
pass_name: Union[str, StringAttr],
|
pass_name: Union[str, StringAttr],
|
||||||
*,
|
*,
|
||||||
options: Optional[
|
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
|
||||||
Dict[
|
|
||||||
Union[str, StringAttr],
|
|
||||||
Union[Attribute, Value, Operation, OpView, str, int, bool],
|
|
||||||
]
|
|
||||||
] = None,
|
|
||||||
loc=None,
|
loc=None,
|
||||||
ip=None,
|
ip=None,
|
||||||
):
|
):
|
||||||
@ -243,26 +243,32 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
|
|||||||
context = (loc and loc.context) or Context.current
|
context = (loc and loc.context) or Context.current
|
||||||
|
|
||||||
cur_param_operand_idx = 0
|
cur_param_operand_idx = 0
|
||||||
|
|
||||||
|
def option_value_to_attr(value):
|
||||||
|
nonlocal cur_param_operand_idx
|
||||||
|
if isinstance(value, (Value, Operation, OpView)):
|
||||||
|
dynamic_options.append(_get_op_result_or_value(value))
|
||||||
|
cur_param_operand_idx += 1
|
||||||
|
return ParamOperandAttr(cur_param_operand_idx - 1, context)
|
||||||
|
elif isinstance(value, Attribute):
|
||||||
|
return value
|
||||||
|
# The following cases auto-convert Python values to attributes.
|
||||||
|
elif isinstance(value, bool):
|
||||||
|
return BoolAttr.get(value)
|
||||||
|
elif isinstance(value, int):
|
||||||
|
default_int_type = IntegerType.get_signless(64, context)
|
||||||
|
return IntegerAttr.get(default_int_type, value)
|
||||||
|
elif isinstance(value, str):
|
||||||
|
return StringAttr.get(value)
|
||||||
|
elif isinstance(value, Sequence):
|
||||||
|
return ArrayAttr.get([option_value_to_attr(elt) for elt in value])
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported option type: {type(value)}")
|
||||||
|
|
||||||
for key, value in options.items() if options is not None else {}:
|
for key, value in options.items() if options is not None else {}:
|
||||||
if isinstance(key, StringAttr):
|
if isinstance(key, StringAttr):
|
||||||
key = key.value
|
key = key.value
|
||||||
|
options_dict[key] = option_value_to_attr(value)
|
||||||
if isinstance(value, (Value, Operation, OpView)):
|
|
||||||
dynamic_options.append(_get_op_result_or_value(value))
|
|
||||||
options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context)
|
|
||||||
cur_param_operand_idx += 1
|
|
||||||
elif isinstance(value, Attribute):
|
|
||||||
options_dict[key] = value
|
|
||||||
# The following cases auto-convert Python values to attributes.
|
|
||||||
elif isinstance(value, bool):
|
|
||||||
options_dict[key] = BoolAttr.get(value)
|
|
||||||
elif isinstance(value, int):
|
|
||||||
default_int_type = IntegerType.get_signless(64, context)
|
|
||||||
options_dict[key] = IntegerAttr.get(default_int_type, value)
|
|
||||||
elif isinstance(value, str):
|
|
||||||
options_dict[key] = StringAttr.get(value)
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Unsupported option type: {type(value)}")
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
result,
|
result,
|
||||||
_get_op_result_or_value(target),
|
_get_op_result_or_value(target),
|
||||||
@ -279,12 +285,7 @@ def apply_registered_pass(
|
|||||||
target: Union[Operation, Value, OpView],
|
target: Union[Operation, Value, OpView],
|
||||||
pass_name: Union[str, StringAttr],
|
pass_name: Union[str, StringAttr],
|
||||||
*,
|
*,
|
||||||
options: Optional[
|
options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None,
|
||||||
Dict[
|
|
||||||
Union[str, StringAttr],
|
|
||||||
Union[Attribute, Value, Operation, OpView, str, int, bool],
|
|
||||||
]
|
|
||||||
] = None,
|
|
||||||
loc=None,
|
loc=None,
|
||||||
ip=None,
|
ip=None,
|
||||||
) -> Value:
|
) -> Value:
|
||||||
|
@ -164,6 +164,128 @@ module attributes {transform.with_named_sequence} {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func private @valid_multiple_values_as_list_option_single_param()
|
||||||
|
module {
|
||||||
|
func.func @valid_multiple_values_as_list_option_single_param() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @a()
|
||||||
|
func.func @a() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func @b()
|
||||||
|
func.func @b() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||||
|
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
|
||||||
|
%symbol_a = transform.param.constant "a" -> !transform.any_param
|
||||||
|
%symbol_b = transform.param.constant "b" -> !transform.any_param
|
||||||
|
%multiple_symbol_names = transform.merge_handles %symbol_a, %symbol_b : !transform.any_param
|
||||||
|
transform.apply_registered_pass "symbol-privatize"
|
||||||
|
with options = { exclude = %multiple_symbol_names } to %2
|
||||||
|
: (!transform.any_op, !transform.any_param) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func private @valid_array_attr_as_list_option()
|
||||||
|
module {
|
||||||
|
func.func @valid_array_attr_as_list_option() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @a()
|
||||||
|
func.func @a() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func @b()
|
||||||
|
func.func @b() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||||
|
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
|
||||||
|
transform.apply_registered_pass "symbol-privatize"
|
||||||
|
with options = { exclude = ["a", "b"] } to %2
|
||||||
|
: (!transform.any_op) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func private @valid_array_attr_param_as_list_option()
|
||||||
|
module {
|
||||||
|
func.func @valid_array_attr_param_as_list_option() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @a()
|
||||||
|
func.func @a() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func @b()
|
||||||
|
func.func @b() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||||
|
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
|
||||||
|
%multiple_symbol_names = transform.param.constant ["a","b"] -> !transform.any_param
|
||||||
|
transform.apply_registered_pass "symbol-privatize"
|
||||||
|
with options = { exclude = %multiple_symbol_names } to %2
|
||||||
|
: (!transform.any_op, !transform.any_param) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func private @valid_multiple_params_as_single_list_option()
|
||||||
|
module {
|
||||||
|
func.func @valid_multiple_params_as_single_list_option() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: func @a()
|
||||||
|
func.func @a() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: func @b()
|
||||||
|
func.func @b() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
module attributes {transform.with_named_sequence} {
|
||||||
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||||
|
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
|
%2 = transform.get_parent_op %1 { deduplicate } : (!transform.any_op) -> !transform.any_op
|
||||||
|
%symbol_a = transform.param.constant "a" -> !transform.any_param
|
||||||
|
%symbol_b = transform.param.constant "b" -> !transform.any_param
|
||||||
|
transform.apply_registered_pass "symbol-privatize"
|
||||||
|
with options = { exclude = [%symbol_a, %symbol_b] } to %2
|
||||||
|
: (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
|
||||||
|
transform.yield
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func.func @invalid_options_as_str() {
|
func.func @invalid_options_as_str() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -203,7 +325,8 @@ func.func @invalid_options_due_to_reserved_attr() {
|
|||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
||||||
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||||
// expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
|
// expected-error @+3 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
|
||||||
|
// expected-error @+2 {{expected a valid attribute or operand as value associated to key 'top-down'}}
|
||||||
%2 = transform.apply_registered_pass "canonicalize"
|
%2 = transform.apply_registered_pass "canonicalize"
|
||||||
with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
|
with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
|
||||||
transform.yield
|
transform.yield
|
||||||
@ -262,26 +385,6 @@ module attributes {transform.with_named_sequence} {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func.func @too_many_pass_option_params() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
|
||||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
|
|
||||||
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
|
|
||||||
%x = transform.param.constant true -> !transform.any_param
|
|
||||||
%y = transform.param.constant false -> !transform.any_param
|
|
||||||
%topdown_options = transform.merge_handles %x, %y : !transform.any_param
|
|
||||||
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
|
|
||||||
transform.apply_registered_pass "canonicalize"
|
|
||||||
with options = { "top-down" = %topdown_options } to %1
|
|
||||||
: (!transform.any_op, !transform.any_param) -> !transform.any_op
|
|
||||||
transform.yield
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
module attributes {transform.with_named_sequence} {
|
module attributes {transform.with_named_sequence} {
|
||||||
// expected-error @below {{trying to schedule a pass on an unsupported operation}}
|
// expected-error @below {{trying to schedule a pass on an unsupported operation}}
|
||||||
// expected-note @below {{target op}}
|
// expected-note @below {{target op}}
|
||||||
|
@ -256,30 +256,45 @@ def testReplicateOp(module: Module):
|
|||||||
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
|
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
|
||||||
|
|
||||||
|
|
||||||
|
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
|
||||||
@run
|
@run
|
||||||
def testApplyRegisteredPassOp(module: Module):
|
def testApplyRegisteredPassOp(module: Module):
|
||||||
|
# CHECK: transform.sequence
|
||||||
sequence = transform.SequenceOp(
|
sequence = transform.SequenceOp(
|
||||||
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
|
||||||
)
|
)
|
||||||
with InsertionPoint(sequence.body):
|
with InsertionPoint(sequence.body):
|
||||||
|
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
|
||||||
mod = transform.ApplyRegisteredPassOp(
|
mod = transform.ApplyRegisteredPassOp(
|
||||||
transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
|
transform.AnyOpType.get(), sequence.bodyTarget, "canonicalize"
|
||||||
)
|
)
|
||||||
|
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
|
||||||
|
# CHECK-SAME: with options = {"top-down" = false}
|
||||||
|
# CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
|
||||||
mod = transform.ApplyRegisteredPassOp(
|
mod = transform.ApplyRegisteredPassOp(
|
||||||
transform.AnyOpType.get(),
|
transform.AnyOpType.get(),
|
||||||
mod.result,
|
mod.result,
|
||||||
"canonicalize",
|
"canonicalize",
|
||||||
options={"top-down": BoolAttr.get(False)},
|
options={"top-down": BoolAttr.get(False)},
|
||||||
)
|
)
|
||||||
|
# CHECK: %[[MAX_ITER:.+]] = transform.param.constant
|
||||||
max_iter = transform.param_constant(
|
max_iter = transform.param_constant(
|
||||||
transform.AnyParamType.get(),
|
transform.AnyParamType.get(),
|
||||||
IntegerAttr.get(IntegerType.get_signless(64), 10),
|
IntegerAttr.get(IntegerType.get_signless(64), 10),
|
||||||
)
|
)
|
||||||
|
# CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
|
||||||
max_rewrites = transform.param_constant(
|
max_rewrites = transform.param_constant(
|
||||||
transform.AnyParamType.get(),
|
transform.AnyParamType.get(),
|
||||||
IntegerAttr.get(IntegerType.get_signless(64), 1),
|
IntegerAttr.get(IntegerType.get_signless(64), 1),
|
||||||
)
|
)
|
||||||
transform.apply_registered_pass(
|
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
|
||||||
|
# NB: MLIR has sorted the dict lexicographically by key:
|
||||||
|
# CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
|
||||||
|
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
|
||||||
|
# CHECK-SAME: "test-convergence" = true,
|
||||||
|
# CHECK-SAME: "top-down" = false}
|
||||||
|
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
|
||||||
|
mod = transform.apply_registered_pass(
|
||||||
transform.AnyOpType.get(),
|
transform.AnyOpType.get(),
|
||||||
mod,
|
mod,
|
||||||
"canonicalize",
|
"canonicalize",
|
||||||
@ -290,19 +305,30 @@ def testApplyRegisteredPassOp(module: Module):
|
|||||||
"max-rewrites": max_rewrites,
|
"max-rewrites": max_rewrites,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
# CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
|
||||||
|
# CHECK-SAME: with options = {"exclude" = ["a", "b"]}
|
||||||
|
# CHECK-SAME: to %{{.*}} : (!transform.any_op) -> !transform.any_op
|
||||||
|
mod = transform.apply_registered_pass(
|
||||||
|
transform.AnyOpType.get(),
|
||||||
|
mod,
|
||||||
|
"symbol-privatize",
|
||||||
|
options={"exclude": ("a", "b")},
|
||||||
|
)
|
||||||
|
# CHECK: %[[SYMBOL_A:.+]] = transform.param.constant
|
||||||
|
symbol_a = transform.param_constant(
|
||||||
|
transform.AnyParamType.get(), StringAttr.get("a")
|
||||||
|
)
|
||||||
|
# CHECK: %[[SYMBOL_B:.+]] = transform.param.constant
|
||||||
|
symbol_b = transform.param_constant(
|
||||||
|
transform.AnyParamType.get(), StringAttr.get("b")
|
||||||
|
)
|
||||||
|
# CHECK: %{{.*}} = apply_registered_pass "symbol-privatize"
|
||||||
|
# CHECK-SAME: with options = {"exclude" = [%[[SYMBOL_A]], %[[SYMBOL_B]]]}
|
||||||
|
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
|
||||||
|
mod = transform.apply_registered_pass(
|
||||||
|
transform.AnyOpType.get(),
|
||||||
|
mod,
|
||||||
|
"symbol-privatize",
|
||||||
|
options={"exclude": (symbol_a, symbol_b)},
|
||||||
|
)
|
||||||
transform.YieldOp()
|
transform.YieldOp()
|
||||||
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
|
|
||||||
# CHECK: transform.sequence
|
|
||||||
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
|
|
||||||
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
|
|
||||||
# CHECK-SAME: with options = {"top-down" = false}
|
|
||||||
# CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
|
|
||||||
# CHECK: %[[MAX_ITER:.+]] = transform.param.constant
|
|
||||||
# CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
|
|
||||||
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
|
|
||||||
# NB: MLIR has sorted the dict lexicographically by key:
|
|
||||||
# CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
|
|
||||||
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
|
|
||||||
# CHECK-SAME: "test-convergence" = true,
|
|
||||||
# CHECK-SAME: "top-down" = false}
|
|
||||||
# CHECK-SAME: to %{{.*}} : (!transform.any_op, !transform.any_param, !transform.any_param) -> !transform.any_op
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user