[MLIR][Transform] apply_registered_pass op's options as a dict (#143159)

Improve ApplyRegisteredPassOp's support for taking options by taking
them as a dict (vs a list of string-valued key-value pairs).

Values of options are provided as either static attributes or as params
(which pass in attributes at interpreter runtime). In either case, the
keys and value attributes are converted to strings and a single
options-string, in the format used on the commandline, is constructed to
pass to the `addToPipeline`-pass API.
This commit is contained in:
Rolf Morel 2025-06-11 17:33:55 +01:00 committed by GitHub
parent ec8d68b59f
commit fe7bf4b90b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 467 additions and 114 deletions

View File

@ -20,6 +20,10 @@ mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRTransformDialectEnumIncGen)
add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
mlir_tablegen(TransformAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(TransformAttrs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRTransformDialectAttributesIncGen)
add_dependencies(mlir-headers MLIRTransformDialectAttributesIncGen)
add_mlir_dialect(TransformOps transform)
add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform)

View File

@ -17,4 +17,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformAttrs.h.inc"
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H

View File

@ -10,6 +10,14 @@
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
class Transform_Attr<string name, string attrMnemonic,
list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<Transform_Dialect, name, traits, baseCppClass> {
let mnemonic = attrMnemonic;
}
def PropagateFailuresCase : I32EnumAttrCase<"Propagate", 1, "propagate">;
def SuppressFailuresCase : I32EnumAttrCase<"Suppress", 2, "suppress">;
@ -33,4 +41,15 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
let cppNamespace = "::mlir::transform";
}
def ParamOperandAttr : Transform_Attr<"ParamOperand", "param_operand"> {
let description = [{
Used to refer to a specific param-operand (via its index) from within an
attribute on a transform operation.
}];
let parameters = (ins
"IntegerAttr":$index
);
let assemblyFormat = "`<` `index` `=` $index `>`";
}
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS

View File

@ -19,6 +19,7 @@ def Transform_Dialect : Dialect {
let cppNamespace = "::mlir::transform";
let hasOperationAttrVerify = 1;
let useDefaultAttributePrinterParser = 1;
let extraClassDeclaration = [{
/// Symbol name for the default entry point "named sequence".
constexpr const static ::llvm::StringLiteral

View File

@ -405,10 +405,23 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
let description = [{
This transform applies the specified pass or pass pipeline to the targeted
ops. The name of the pass/pipeline is specified as a string attribute, as
set during pass/pipeline registration. Optionally, pass options may be
specified as (space-separated) string attributes with the option to pass
these attributes via params. The pass options syntax is identical to the one
used with "mlir-opt".
set during pass/pipeline registration.
Optionally, pass options may be specified via a DictionaryAttr. This
dictionary is converted to a string -- formatted `key=value ...` -- which
is expected to be in the exact format used by the pass on the commandline.
Values are either attributes or (SSA-values of) Transform Dialect params.
For example:
```mlir
transform.apply_registered_pass "canonicalize"
with options = { "top-down" = false,
"max-iterations" = %max_iter,
"test-convergence" = true,
"max-num-rewrites" = %max_rewrites }
to %module
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
```
This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
@ -422,7 +435,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
}];
let arguments = (ins StrAttr:$pass_name,
DefaultValuedAttr<ArrayAttr, "{}">:$options,
DefaultValuedAttr<DictionaryAttr, "{}">:$options,
Variadic<TransformParamTypeInterface>:$dynamic_options,
TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$result);

View File

@ -8,17 +8,22 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Analysis/CallGraph.h"
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
#ifndef NDEBUG
void transform::detail::checkImplementsTransformOpInterface(
StringRef name, MLIRContext *context) {
@ -66,6 +71,10 @@ void transform::TransformDialect::initialize() {
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
initializeTypes();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
>();
initializeLibraryModule();
}

View File

@ -54,10 +54,11 @@
using namespace mlir;
static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser &parser, ArrayAttr &options,
OpAsmParser &parser, DictionaryAttr &options,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
Operation *op, ArrayAttr options,
Operation *op,
DictionaryAttr options,
ValueRange dynamicOptions);
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
@ -784,41 +785,50 @@ DiagnosedSilenceableFailure
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
// Obtain a single options-string from options passed statically as
// string attributes as well as "dynamically" through params.
std::string options;
OperandRange dynamicOptions = getDynamicOptions();
size_t dynamicOptionsIdx = 0;
for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) {
if (idx > 0)
options += " "; // Interleave options seperator.
// Obtain a single options-string to pass to the pass(-pipeline) from options
// passed in as a dictionary of keys mapping to values which are either
// attributes or param-operands pointing to attributes.
if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
options += strAttr.getValue();
} else if (isa<UnitAttr>(optionAttr)) {
assert(dynamicOptionsIdx < dynamicOptions.size() &&
std::string options;
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
OperandRange dynamicOptions = getDynamicOptions();
for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
if (idx > 0)
optionsStream << " "; // Interleave options separator.
optionsStream << namedAttribute.getName().str(); // Append the key.
optionsStream << "="; // And the key-value separator.
Attribute valueAttrToAppend;
if (auto paramOperandIndex =
dyn_cast<transform::ParamOperandAttr>(namedAttribute.getValue())) {
// The corresponding value attribute is passed in via a param.
// Obtain the param-operand via its specified index.
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
assert(dynamicOptionIdx < dynamicOptions.size() &&
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
"should be the same as the number of options passed as params");
ArrayRef<Attribute> dynamicOption =
state.getParams(dynamicOptions[dynamicOptionsIdx++]);
state.getParams(dynamicOptions[dynamicOptionIdx]);
if (dynamicOption.size() != 1)
return emitSilenceableError() << "options passed as a param must have "
"a single value associated, param "
<< dynamicOptionsIdx - 1 << " associates "
<< dynamicOption.size();
if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) {
options += dynamicOptionStr.getValue();
} else {
return emitSilenceableError()
<< "options passed as a param must be a string, got "
<< dynamicOption[0];
}
<< "options passed as a param must have "
"a single value associated, param "
<< dynamicOptionIdx << " associates " << dynamicOption.size();
valueAttrToAppend = dynamicOption[0];
} else {
llvm_unreachable(
"expected options element to be either StringAttr or UnitAttr");
// Value is a static attribute.
valueAttrToAppend = namedAttribute.getValue();
}
// Append string representation of value attribute.
if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
optionsStream << strAttr.getValue().str();
} else {
valueAttrToAppend.print(optionsStream, /*elideType=*/true);
}
}
optionsStream.flush();
// Get pass or pass pipeline from registry.
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@ -864,84 +874,121 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
}
static ParseResult parseApplyRegisteredPassOptions(
OpAsmParser &parser, ArrayAttr &options,
OpAsmParser &parser, DictionaryAttr &options,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
auto dynamicOptionMarker = UnitAttr::get(parser.getContext());
SmallVector<Attribute> optionsArray;
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
SmallVector<NamedAttribute> keyValuePairs;
auto parseOperandOrString = [&]() -> OptionalParseResult {
OpAsmParser::UnresolvedOperand operand;
OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand);
if (parsedOperand.has_value()) {
if (failed(parsedOperand.value()))
return failure();
size_t dynamicOptionsIdx = 0;
auto parseKeyValuePair = [&]() -> ParseResult {
// Parse items of the form `key = value` where `key` is a bare identifier or
// a string and `value` is either an attribute or an operand.
std::string key;
Attribute valueAttr;
if (parser.parseOptionalKeywordOrString(&key))
return parser.emitError(parser.getCurrentLocation())
<< "expected key to either be an identifier or a string";
if (key.empty())
return failure();
if (parser.parseEqual())
return parser.emitError(parser.getCurrentLocation())
<< "expected '=' after key in key-value pair";
// Parse the value, which can be either an attribute or an operand.
OptionalParseResult parsedValueAttr =
parser.parseOptionalAttribute(valueAttr);
if (!parsedValueAttr.has_value()) {
OpAsmParser::UnresolvedOperand operand;
ParseResult parsedOperand = parser.parseOperand(operand);
if (failed(parsedOperand))
return parser.emitError(parser.getCurrentLocation())
<< "expected a valid attribute or operand as value associated "
<< "to key '" << key << "'";
// To make use of the operand, we need to store it in the options dict.
// As SSA-values cannot occur in attributes, what we do instead is store
// an attribute in its place that contains the index of the param-operand,
// so that an attr-value associated to the param can be resolved later on.
dynamicOptions.push_back(operand);
optionsArray.push_back(
dynamicOptionMarker); // Placeholder for knowing where to
// inject the dynamic option-as-param.
return success();
auto wrappedIndex = IntegerAttr::get(
IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
valueAttr =
transform::ParamOperandAttr::get(parser.getContext(), wrappedIndex);
} else if (failed(parsedValueAttr.value())) {
return failure(); // NB: Attempted parse should have output error message.
} else if (isa<transform::ParamOperandAttr>(valueAttr)) {
return parser.emitError(parser.getCurrentLocation())
<< "the param_operand attribute is a marker reserved for "
<< "indicating a value will be passed via params and is only used "
<< "in the generic print format";
}
StringAttr stringAttr;
OptionalParseResult parsedStringAttr =
parser.parseOptionalAttribute(stringAttr);
if (parsedStringAttr.has_value()) {
if (failed(parsedStringAttr.value()))
return failure();
optionsArray.push_back(stringAttr);
return success();
}
return std::nullopt;
keyValuePairs.push_back(NamedAttribute(key, valueAttr));
return success();
};
OptionalParseResult parsedOptionsElement = parseOperandOrString();
while (parsedOptionsElement.has_value()) {
if (failed(parsedOptionsElement.value()))
return failure();
parsedOptionsElement = parseOperandOrString();
}
if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces,
parseKeyValuePair,
" in options dictionary"))
return failure(); // NB: Attempted parse should have output error message.
if (optionsArray.empty()) {
if (DictionaryAttr::findDuplicate(
keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
.has_value())
return parser.emitError(parser.getCurrentLocation())
<< "expected at least one option (either a string or a param)";
}
options = parser.getBuilder().getArrayAttr(optionsArray);
<< "duplicate keys found in options dictionary";
options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
return success();
}
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
Operation *op, ArrayAttr options,
Operation *op,
DictionaryAttr options,
ValueRange dynamicOptions) {
size_t currentDynamicOptionIdx = 0;
for (auto [idx, optionAttr] : llvm::enumerate(options)) {
if (idx > 0)
printer << " "; // Interleave options separator.
if (options.empty())
return;
if (isa<UnitAttr>(optionAttr))
printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]);
else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
printer.printAttribute(strAttr);
else
llvm_unreachable("each option should be either a StringAttr or UnitAttr");
}
printer << "{";
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
printer << namedAttribute.getName() << " = ";
Attribute value = namedAttribute.getValue();
if (auto indexAttr = dyn_cast<transform::ParamOperandAttr>(value)) {
// Resolve index of param-operand to its actual SSA-value and print that.
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
} else {
printer.printAttribute(value);
}
});
printer << "}";
}
LogicalResult transform::ApplyRegisteredPassOp::verify() {
size_t numUnitsInOptions = 0;
for (Attribute optionsElement : getOptions()) {
if (isa<UnitAttr>(optionsElement))
numUnitsInOptions++;
else if (!isa<StringAttr>(optionsElement))
return emitOpError() << "expected each option to be either a StringAttr "
<< "or a UnitAttr, got " << optionsElement;
}
// Check that there is a one-to-one correspondence between param operands
// and references to dynamic options in the options dictionary.
if (getDynamicOptions().size() != numUnitsInOptions)
return emitOpError()
<< "expected the same number of options passed as params as "
<< "UnitAttr elements in options ArrayAttr";
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
for (NamedAttribute namedAttr : getOptions())
if (auto paramOperand =
dyn_cast<transform::ParamOperandAttr>(namedAttr.getValue())) {
size_t dynamicOptionIdx = paramOperand.getIndex().getInt();
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
return emitOpError()
<< "dynamic option index " << dynamicOptionIdx
<< " is out of bounds for the number of dynamic options: "
<< dynamicOptions.size();
if (dynamicOptions[dynamicOptionIdx] == nullptr)
return emitOpError() << "dynamic option index " << dynamicOptionIdx
<< " is already used in options";
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
}
for (Value dynamicOption : dynamicOptions)
if (dynamicOption)
return emitOpError() << "a param operand does not have a corresponding "
<< "param_operand attr in the options dict";
return success();
}

View File

@ -18,7 +18,12 @@ try:
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
from typing import Optional, Sequence, Union, NewType
from typing import Dict, Optional, Sequence, Union, NewType
@register_attribute_builder("ParamOperandAttr")
def _paramOperandAttr(x: int, context) -> Attribute:
return Attribute.parse(f"#transform.param_operand<index={x}>", context=context)
@_ods_cext.register_operation(_Dialect, replace=True)
@ -214,6 +219,81 @@ class YieldOp(YieldOp):
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
@_ods_cext.register_operation(_Dialect, replace=True)
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
def __init__(
self,
result: Type,
pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
*,
options: Optional[
Dict[
Union[str, StringAttr],
Union[Attribute, Value, Operation, OpView],
]
] = None,
loc=None,
ip=None,
):
options_dict = {}
dynamic_options = []
ParamOperandAttr = AttrBuilder.get("ParamOperandAttr")
context = (loc and loc.context) or Context.current
cur_param_operand_idx = 0
for key, value in options.items() if options is not None else {}:
if isinstance(key, StringAttr):
key = key.value
if isinstance(value, (Value, Operation, OpView)):
dynamic_options.append(_get_op_result_or_value(value))
options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context)
cur_param_operand_idx += 1
elif isinstance(value, Attribute):
options_dict[key] = value
elif isinstance(value, str):
options_dict[key] = StringAttr.get(value)
else:
raise TypeError(f"Unsupported option type: {type(value)}")
if len(options_dict) > 0:
print(options_dict, cur_param_operand_idx)
super().__init__(
result,
pass_name,
dynamic_options,
target=_get_op_result_or_value(target),
options=DictAttr.get(options_dict),
loc=loc,
ip=ip,
)
def apply_registered_pass(
result: Type,
pass_name: Union[str, StringAttr],
target: Union[Operation, Value, OpView],
*,
options: Optional[
Dict[
Union[str, StringAttr],
Union[Attribute, Value, Operation, OpView],
]
] = None,
loc=None,
ip=None,
) -> Value:
return ApplyRegisteredPassOp(
result=result,
pass_name=pass_name,
target=target,
options=options,
loc=loc,
ip=ip,
).result
AnyOpTypeT = NewType("AnyOpType", AnyOpType)

View File

@ -80,7 +80,7 @@ module attributes {transform.with_named_sequence} {
// expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
// expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
transform.apply_registered_pass "canonicalize"
with options = "invalid-option=1" to %1
with options = { "invalid-option" = 1 } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@ -97,7 +97,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize"
with options = "top-down=false" to %1
with options = { "top-down" = false } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
//transform.apply_registered_pass "canonicalize" with options = "top-down=false,max-iterations=10" to %1 : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize"
with options = "top-down=false test-convergence=true" to %1
with options = { "top-down" = false, "test-convergence" =true } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@ -132,7 +132,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "canonicalize"
with options = "top-down=false" "max-iterations=0" to %1
with options = { "top-down" = false, "max-iterations" = 0 } to %1
: (!transform.any_op) -> !transform.any_op
transform.yield
}
@ -148,10 +148,15 @@ func.func @valid_dynamic_pass_options() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
%max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param
%2 = transform.apply_registered_pass "canonicalize"
with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1
%max_iter = transform.param.constant 10 -> !transform.any_param
%max_rewrites = transform.param.constant 1 -> !transform.any_param
%2 = transform.apply_registered_pass
"canonicalize"
with options = { "top-down" = false,
"max-iterations" = %max_iter,
"test-convergence" = true,
"max-num-rewrites" = %max_rewrites }
to %1
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
@ -159,7 +164,7 @@ module attributes {transform.with_named_sequence} {
// -----
func.func @invalid_dynamic_options_as_array() {
func.func @invalid_options_as_str() {
return
}
@ -167,34 +172,80 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
// expected-error @+2 {{expected at least one option (either a string or a param)}}
// expected-error @+2 {{expected '{' in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
with options = ["top-down=false" %max_iter] to %1
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
func.func @invalid_options_as_pairs() {
func.func @invalid_options_as_pairs_without_braces() {
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+2 {{expected 'to'}}
// expected-error @+2 {{expected '{' in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
with options = "top-down=" false to %1
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
with options = "top-down"=false to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
func.func @invalid_pass_option_param() {
func.func @invalid_options_due_to_reserved_attr() {
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+2 {{the param_operand attribute is a marker reserved for indicating a value will be passed via params and is only used in the generic print format}}
%2 = transform.apply_registered_pass "canonicalize"
with options = { "top-down" = #transform.param_operand<index=0> } to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
func.func @invalid_options_due_duplicated_key() {
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+2 {{duplicate keys found in options dictionary}}
%2 = transform.apply_registered_pass "canonicalize"
with options = {"top-down"=false,"top-down"=true} to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
func.func @invalid_options_due_invalid_key() {
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @+2 {{expected key to either be an identifier or a string}}
%2 = transform.apply_registered_pass "canonicalize"
with options = { @label = 0 } to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
// -----
func.func @invalid_pass_option_bare_param() {
return
}
@ -202,7 +253,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%pass_options = transform.param.constant 42 -> !transform.any_param
// expected-error @below {{options passed as a param must be a string, got 42}}
// expected-error @+2 {{expected '{' in options dictionary}}
transform.apply_registered_pass "canonicalize"
with options = %pass_options to %1
: (!transform.any_param, !transform.any_op) -> !transform.any_op
@ -219,12 +270,12 @@ func.func @too_many_pass_option_params() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%x = transform.param.constant "x" -> !transform.any_param
%y = transform.param.constant "y" -> !transform.any_param
%pass_options = transform.merge_handles %x, %y : !transform.any_param
%x = transform.param.constant true -> !transform.any_param
%y = transform.param.constant false -> !transform.any_param
%topdown_options = transform.merge_handles %x, %y : !transform.any_param
// expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}}
transform.apply_registered_pass "canonicalize"
with options = %pass_options to %1
with options = { "top-down" = %topdown_options } to %1
: (!transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
@ -248,3 +299,77 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// -----
/////////////////////////////////////////////////////////////////////
// Check that the following cases are caugh in the generic format. //
/////////////////////////////////////////////////////////////////////
// Invalid due to param_operand occurences in options dict not being
// one-to-one with the dynamic options provided as params:
// param_operand_index out of bounds w.r.t. the number of options provided via params.
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 1 is out of bounds for the number of dynamic options: 1}}
%2 = "transform.apply_registered_pass"(%1, %0) <{
options = {"max-iterations" = #transform.param_operand<index=1 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
: (!transform.any_param, !transform.any_op) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
// -----
// Invalid due to param_operand occurences in options dict not being
// one-to-one with the dynamic options provided as params:
// the first option-param is referred to twice and the second one not at all.
// (In the pretty-printed format, if you want to refer to a param SSA-value twice, it counts as two param arguments.)
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{dynamic option index 0 is already used in options}}
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"max-num-rewrites" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()
// -----
// Invalid due to param_operand occurences in options dict not being
// one-to-one with the dynamic options provided as params:
// two option-params are provide though only the first one is referred to from the options-dict.
"builtin.module"() ({
"transform.named_sequence"() <{function_type = (!transform.any_op) -> (), sym_name = "__transform_main"}> ({
^bb0(%arg0: !transform.any_op):
%0 = "transform.structured.match"(%arg0) <{ops = ["func.func"]}> : (!transform.any_op) -> !transform.any_op
%1 = "transform.param.constant"() <{value = 10 : i64}> : () -> !transform.any_param
%2 = "transform.param.constant"() <{value = 1 : i64}> : () -> !transform.any_param
// expected-error @below {{a param operand does not have a corresponding param_operand attr in the options dict}}
%3 = "transform.apply_registered_pass"(%1, %2, %0) <{
options = {"max-iterations" = #transform.param_operand<index=0 : i64>,
"test-convergence" = true,
"top-down" = false},
pass_name = "canonicalize"}>
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
"transform.yield"() : () -> ()
}) : () -> ()
}) {transform.with_named_sequence} : () -> ()

View File

@ -254,3 +254,55 @@ def testReplicateOp(module: Module):
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
@run
def testApplyRegisteredPassOp(module: Module):
sequence = transform.SequenceOp(
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
)
with InsertionPoint(sequence.body):
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
)
mod = transform.ApplyRegisteredPassOp(
transform.AnyOpType.get(),
"canonicalize",
mod.result,
options={"top-down": BoolAttr.get(False)},
)
max_iter = transform.param_constant(
transform.AnyParamType.get(),
IntegerAttr.get(IntegerType.get_signless(64), 10),
)
max_rewrites = transform.param_constant(
transform.AnyParamType.get(),
IntegerAttr.get(IntegerType.get_signless(64), 1),
)
transform.apply_registered_pass(
transform.AnyOpType.get(),
"canonicalize",
mod,
options={
"top-down": BoolAttr.get(False),
"max-iterations": max_iter,
"test-convergence": BoolAttr.get(True),
"max-rewrites": max_rewrites,
},
)
transform.YieldOp()
# CHECK-LABEL: TEST: testApplyRegisteredPassOp
# CHECK: transform.sequence
# CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
# CHECK-SAME: with options = {"top-down" = false}
# CHECK-SAME: to {{.*}} : (!transform.any_op) -> !transform.any_op
# CHECK: %[[MAX_ITER:.+]] = transform.param.constant
# CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
# CHECK: %{{.*}} = apply_registered_pass "canonicalize"
# NB: MLIR has sorted the dict lexicographically by key:
# CHECK-SAME: with options = {"max-iterations" = %[[MAX_ITER]],
# CHECK-SAME: "max-rewrites" = %[[MAX_REWRITE]],
# CHECK-SAME: "test-convergence" = true,
# CHECK-SAME: "top-down" = false}
# CHECK-SAME: to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op