[MLIR][Transform][Tune] Introduce transform.tune.alternatives op (#160724)
This op makes it possible to express uncertainty about what should happen at a particular place in a transform-dialect schedule. Specifically, it represents a choice among alternative regions. The choice is resolved by providing a `selected_region` argument. When this argument is provided, it is valid to rewrite the op by substituting in the selected region, and the op's interpreted semantics correspond to exactly that substitution. This op is another piece of the puzzle w.r.t. a toolkit for expressing autotuning problems with the transform dialect. Note that it goes beyond tuning knobs _on_ transforms: it also makes tunable which (sequences of) transforms are to be applied.
This commit is contained in:
parent
a33544b83c
commit
f4d18c0ef8
@ -9,6 +9,7 @@
|
||||
#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
|
||||
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
|
||||
@ -11,10 +11,15 @@
|
||||
|
||||
include "mlir/Dialect/Transform/IR/TransformDialect.td"
|
||||
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/BuiltinAttributes.td"
|
||||
include "mlir/IR/CommonAttrConstraints.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// KnobOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def KnobOp : Op<Transform_Dialect, "tune.knob", [
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
@ -52,4 +57,53 @@ def KnobOp : Op<Transform_Dialect, "tune.knob", [
|
||||
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AlternativesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AlternativesOp : Op<Transform_Dialect, "tune.alternatives", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface,
|
||||
["getEntrySuccessorOperands", "getSuccessorRegions",
|
||||
"getRegionInvocationBounds"]>,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">,
|
||||
NoRegionArguments
|
||||
]> {
|
||||
let summary = "Represents a choice among its regions, i.e. sub-schedules";
|
||||
|
||||
let description = [{
|
||||
This op represents a choice over which of its regions is to be used.
|
||||
|
||||
When `selected_region` is provided, the semantics are that this op is to be
|
||||
substituted for by the selected region, meaning the region's results become
|
||||
the results of this op. Without a provided `selected_region`, the semantics
|
||||
are that this non-deterministic choice is yet to be resolved -- which in
|
||||
terms of the op's interpreted semantics is a failure.
|
||||
|
||||
The `selected_region` argument is either an `IntegerAttr` or a param holding
|
||||
an `IntegerAttr`, which should provide a valid zero-based index with respect
|
||||
to the number of alternatives, i.e. regions.
|
||||
}];
|
||||
let cppNamespace = [{ mlir::transform::tune }];
|
||||
|
||||
let arguments = (ins Builtin_StringAttr:$name,
|
||||
OptionalAttr<APIntAttr>:$selected_region_attr,
|
||||
Optional<TransformParamTypeInterface>:$selected_region_param);
|
||||
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
|
||||
let regions = (region VariadicRegion<SizedRegion<1>>:$alternatives);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`<` $name `>`
|
||||
(`selected_region` `=` custom<AlternativesOpSelectedRegion>(
|
||||
$selected_region_attr, $selected_region_param)^)?
|
||||
attr-dict-with-keyword
|
||||
(`:` type($selected_region_param)^)?
|
||||
(`->` type($results)^)?
|
||||
regions
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
|
||||
|
||||
@ -6,13 +6,24 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.h"
|
||||
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static ParseResult parseAlternativesOpSelectedRegion(
|
||||
OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
|
||||
std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam);
|
||||
|
||||
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
|
||||
Operation *op,
|
||||
IntegerAttr selectedRegionAttr,
|
||||
Value selectedRegionParam);
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
|
||||
|
||||
@ -57,3 +68,176 @@ LogicalResult transform::tune::KnobOp::verify() {
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AlternativesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parses the payload of the `selected_region = ...` clause.
///
/// Two spellings are accepted, tried in this order:
///   1. an integer literal, which is stored in `selectedRegionAttr` as an
///      index-typed attribute;
///   2. an SSA operand, which is stored in `selectedRegionParam`.
/// If neither is present, an error is emitted at the current location.
static ParseResult parseAlternativesOpSelectedRegion(
    OpAsmParser &parser, IntegerAttr &selectedRegionAttr,
    std::optional<OpAsmParser::UnresolvedOperand> &selectedRegionParam) {
  // Spelling 1: a literal region index.
  size_t literalIdx;
  OptionalParseResult literalRes = parser.parseOptionalInteger(literalIdx);
  if (literalRes.has_value()) {
    // An integer was present but malformed: propagate the failure.
    if (failed(*literalRes))
      return failure();
    selectedRegionAttr = parser.getBuilder().getIndexAttr(literalIdx);
    return success();
  }

  // Spelling 2: an operand providing the index.
  OpAsmParser::UnresolvedOperand operand;
  auto operandRes = parser.parseOptionalOperand(operand);
  if (operandRes.has_value()) {
    if (failed(*operandRes))
      return failure();
    selectedRegionParam = operand;
    return success();
  }

  // Neither spelling matched.
  return parser.emitError(parser.getCurrentLocation())
         << "expected either an integer attribute or a transform.param operand";
}
|
||||
|
||||
/// Prints the payload of the `selected_region = ...` clause: the value of
/// `selectedRegionAttr` when that attribute is set, followed by
/// `selectedRegionParam` when that operand is set. `op` is not consulted.
static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
                                              Operation *op,
                                              IntegerAttr selectedRegionAttr,
                                              Value selectedRegionParam) {
  // Only the integer value is printed, not the attribute itself.
  if (selectedRegionAttr) {
    printer << selectedRegionAttr.getValue();
  }

  if (selectedRegionParam) {
    printer << selectedRegionParam;
  }
}
|
||||
|
||||
/// Returns the operands forwarded to a successor on entry. This op forwards
/// none, so the result is always an empty slice of the op's operands.
OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
    RegionBranchPoint point) {
  OperandRange allOperands = getOperands();
  // Empty range: start at 0, take 0 elements.
  return allOperands.slice(/*n=*/0, /*m=*/0);
}
|
||||
|
||||
void transform::tune::AlternativesOp::getSuccessorRegions(
|
||||
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> ®ions) {
|
||||
if (point.isParent())
|
||||
if (auto selectedRegionIdx = getSelectedRegionAttr())
|
||||
regions.emplace_back(
|
||||
&getAlternatives()[selectedRegionIdx->getSExtValue()],
|
||||
Block::BlockArgListType());
|
||||
else
|
||||
for (Region &alternative : getAlternatives())
|
||||
regions.emplace_back(&alternative, Block::BlockArgListType());
|
||||
else
|
||||
regions.emplace_back(getOperation()->getResults());
|
||||
}
|
||||
|
||||
/// Reports how many times each region may be invoked.
///
/// With a `selected_region` attribute, the selected region is invoked exactly
/// once and all others never. Without it, each region is invoked at most
/// once. `operands` is not consulted.
void transform::tune::AlternativesOp::getRegionInvocationBounds(
    ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
  (void)operands;
  const unsigned numRegions = getNumRegions();
  bounds.reserve(numRegions);

  auto selectedRegionIdx = getSelectedRegionAttr();
  if (!selectedRegionIdx) {
    // Choice not fixed by an attribute: any region may run, at most once.
    bounds.resize(numRegions, InvocationBounds(0, 1));
    return;
  }

  // Choice fixed: nothing runs except the selected region, exactly once.
  bounds.resize(numRegions, InvocationBounds(0, 0));
  bounds[selectedRegionIdx->getSExtValue()] = InvocationBounds(1, 1);
}
|
||||
|
||||
/// Declares the op's memory effects: the optional `selected_region` param
/// operand is only read, and all op results are produced.
void transform::tune::AlternativesOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *thisOp = getOperation();
  onlyReadsHandle(getSelectedRegionParamMutable(), effects);
  producesHandle(thisOp->getOpResults(), effects);
  // TODO: should effects from regions be forwarded?
}
|
||||
|
||||
/// Interprets the op by applying the selected alternative.
///
/// The region index comes from the `selected_region` attribute and/or the
/// `selected_region` param operand; when both are present the param's value
/// is used. The param must be associated with exactly one integer attribute.
///
/// Returns a definite failure when no selection is provided, when the param
/// does not hold exactly one integer attribute, or when the index is out of
/// range. Otherwise the ops of the selected region are applied in order:
/// a definite failure is returned immediately; a silenceable failure is
/// returned after setting all results to empty. On success the operands of
/// the region's terminator are forwarded to this op's results.
DiagnosedSilenceableFailure
transform::tune::AlternativesOp::apply(transform::TransformRewriter &rewriter,
                                       transform::TransformResults &results,
                                       transform::TransformState &state) {
  // Kept signed so that a negative index is detected and reported as such
  // rather than wrapping around to a huge unsigned value.
  std::optional<int64_t> selectedRegionIdx;

  if (auto selectedRegionAttr = getSelectedRegionAttr())
    selectedRegionIdx = selectedRegionAttr->getSExtValue();

  if (Value selectedRegionParam = getSelectedRegionParam()) {
    ArrayRef<Attribute> associatedAttrs = state.getParams(selectedRegionParam);

    // Handle the empty case separately: the diagnostic below reads
    // `associatedAttrs[0]`, which must not be done on an empty list.
    if (associatedAttrs.empty())
      return emitDefiniteFailure()
             << "param should hold exactly one integer attribute, got none";

    IntegerAttr selectedRegionAttr;
    if (associatedAttrs.size() != 1 ||
        !(selectedRegionAttr = dyn_cast<IntegerAttr>(associatedAttrs[0])))
      return emitDefiniteFailure()
             << "param should hold exactly one integer attribute, got: "
             << associatedAttrs[0];
    selectedRegionIdx = selectedRegionAttr.getValue().getSExtValue();
  }

  if (!selectedRegionIdx)
    return emitDefiniteFailure() << "non-deterministic choice " << getName()
                                 << " is only resolved through providing a "
                                    "`selected_region` attr/param";

  if (*selectedRegionIdx < 0 ||
      *selectedRegionIdx >= static_cast<int64_t>(getNumRegions()))
    return emitDefiniteFailure()
           << "'selected_region' attribute/param specifies region at index "
           << *selectedRegionIdx << " while op has only " << getNumRegions()
           << " regions";

  Region &selectedRegion = getRegion(static_cast<unsigned>(*selectedRegionIdx));
  auto scope = state.make_region_scope(selectedRegion);
  Block &block = selectedRegion.front();
  // Apply the region's ops one by one.
  for (Operation &transform : block.without_terminator()) {
    DiagnosedSilenceableFailure result =
        state.applyTransform(cast<transform::TransformOpInterface>(transform));
    if (result.isDefiniteFailure())
      return result;

    if (result.isSilenceableFailure()) {
      // Results must still be set before reporting the silenceable failure.
      for (const auto &res : getResults())
        results.set(res, {});
      return result;
    }
  }
  // Forward the operation mapping for values yielded from the region to the
  // values produced by the alternatives op.
  transform::detail::forwardTerminatorOperands(&block, state, results);
  return DiagnosedSilenceableFailure::success();
}
|
||||
|
||||
/// Verifies the op's structural invariants:
///  - every region ends in a `transform.yield` terminator whose operand count
///    and operand types match this op's results;
///  - if the `selected_region` attribute is present, it is a valid zero-based
///    index into the op's regions.
LogicalResult transform::tune::AlternativesOp::verify() {
  for (auto *region : getRegions()) {
    auto yieldTerminator =
        llvm::dyn_cast_if_present<transform::YieldOp>(region->front().back());
    if (!yieldTerminator)
      return emitOpError() << "expected '"
                           << transform::YieldOp::getOperationName()
                           << "' as terminator";

    if (yieldTerminator->getNumOperands() != getNumResults())
      return yieldTerminator.emitOpError()
             << "expected terminator to have as many operands as the parent op "
                "has results";

    for (auto [i, operandType, resultType] : llvm::zip_equal(
             llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
             yieldTerminator->getOperands().getType(), getResultTypes())) {
      if (operandType == resultType)
        continue;
      return yieldTerminator.emitOpError()
             << "the type of the terminator operand #" << i
             << " must match the type of the corresponding parent op result ("
             << operandType << " vs " << resultType << ")";
    }
  }

  if (auto selectedRegionAttr = getSelectedRegionAttr()) {
    // Kept signed: with an unsigned index the `< 0` check could never fire
    // and a negative attribute value would be reported as a huge number.
    const int64_t regionIdx = selectedRegionAttr->getSExtValue();
    if (regionIdx < 0 || regionIdx >= static_cast<int64_t>(getNumRegions()))
      return emitOpError()
             << "'selected_region' attribute specifies region at index "
             << regionIdx << " while op has only " << getNumRegions()
             << " regions";
  }

  return success();
}
|
||||
|
||||
@ -6,6 +6,9 @@ from typing import Optional, Sequence
|
||||
|
||||
from ...ir import (
|
||||
Type,
|
||||
Value,
|
||||
Operation,
|
||||
OpView,
|
||||
Attribute,
|
||||
ArrayAttr,
|
||||
StringAttr,
|
||||
@ -19,7 +22,10 @@ from .._transform_tune_extension_ops_gen import *
|
||||
from .._transform_tune_extension_ops_gen import _Dialect
|
||||
|
||||
try:
|
||||
from .._ods_common import _cext as _ods_cext
|
||||
from .._ods_common import (
|
||||
get_op_result_or_value as _get_op_result_or_value,
|
||||
_cext as _ods_cext,
|
||||
)
|
||||
except ImportError as e:
|
||||
raise RuntimeError("Error loading imports from extension module") from e
|
||||
|
||||
@ -36,7 +42,7 @@ class KnobOp(KnobOp):
|
||||
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
|
||||
],
|
||||
*,
|
||||
selected: Optional[Attribute] = None,
|
||||
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
@ -75,8 +81,62 @@ def knob(
|
||||
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
|
||||
],
|
||||
*,
|
||||
selected: Optional[Attribute] = None,
|
||||
selected: Optional[Union[Attribute, bool, int, float, str]] = None,
|
||||
loc=None,
|
||||
ip=None,
|
||||
):
|
||||
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True)
class AlternativesOp(AlternativesOp):
    """Convenience wrapper around the generated `AlternativesOp` builder.

    Accepts `name` as a plain `str` and `selected_region` as an `int`, an
    `IntegerAttr`, or a param value (or an op producing it), and appends one
    empty block to each region of the created op.
    """

    def __init__(
        self,
        results: Sequence[Type],
        name: Union[StringAttr, str],
        num_alternatives: int,
        *,
        selected_region: Optional[
            Union[int, IntegerAttr, Value, Operation, OpView]
        ] = None,
        loc=None,
        ip=None,
    ):
        """Creates the op.

        Args:
          results: result types of the op.
          name: name of the choice; a `str` is wrapped in a `StringAttr`.
          num_alternatives: forwarded to the generated builder as the number
            of alternatives.
          selected_region: optional selection. An `int` becomes a signless
            32-bit `IntegerAttr`; an `IntegerAttr` is used as-is; a `Value`,
            `Operation` or `OpView` is used as the param operand.
          loc: optional location, forwarded to the generated builder.
          ip: optional insertion point, forwarded to the generated builder.

        Raises:
          TypeError: if `selected_region` is not `None` and has none of the
            supported types (it would otherwise be silently dropped).
        """
        if isinstance(name, str):
            name = StringAttr.get(name)

        selected_region_attr = selected_region_param = None
        if isinstance(selected_region, IntegerAttr):
            selected_region_attr = selected_region
        elif isinstance(selected_region, int):
            selected_region_attr = IntegerAttr.get(
                IntegerType.get_signless(32), selected_region
            )
        elif isinstance(selected_region, (Value, Operation, OpView)):
            selected_region_param = _get_op_result_or_value(selected_region)
        elif selected_region is not None:
            raise TypeError(
                "selected_region must be an int, IntegerAttr, Value, Operation "
                f"or OpView, got {type(selected_region).__name__}"
            )

        super().__init__(
            results,
            name,
            num_alternatives,
            selected_region_attr=selected_region_attr,
            selected_region_param=selected_region_param,
            loc=loc,
            ip=ip,
        )
        # Give every alternative an (initially empty) block to insert into.
        for region in self.regions:
            region.blocks.append()
|
||||
|
||||
|
||||
def alternatives(
    results: Sequence[Type],
    name: Union[StringAttr, str],
    num_alternatives: int,
    *,
    selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None,
    loc=None,
    ip=None,
):
    """Builds an `AlternativesOp` from the given arguments and returns it.

    All arguments are forwarded unchanged to the `AlternativesOp` constructor.
    """
    op = AlternativesOp(
        results,
        name,
        num_alternatives,
        selected_region=selected_region,
        loc=loc,
        ip=ip,
    )
    return op
|
||||
|
||||
@ -19,3 +19,88 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @f()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
// expected-error@below {{'selected_region' attribute specifies region at index 2 while op has only 2 regions}}
|
||||
transform.tune.alternatives<"bifurcation"> selected_region = 2 {
|
||||
transform.yield
|
||||
}, {
|
||||
transform.yield
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @f()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%singleton_of_c0 = transform.param.constant [0] -> !transform.any_param
|
||||
// expected-error@below {{param should hold exactly one integer attribute, got: [0]}}
|
||||
transform.tune.alternatives<"bifurcation"> selected_region = %singleton_of_c0 : !transform.any_param {
|
||||
transform.yield
|
||||
}, {
|
||||
transform.yield
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @f()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%c0 = transform.param.constant 0 -> !transform.any_param
|
||||
%c1 = transform.param.constant 1 -> !transform.any_param
|
||||
%c0_and_c1 = transform.merge_handles %c0, %c1 : !transform.any_param
|
||||
// expected-error@below {{param should hold exactly one integer attribute}}
|
||||
transform.tune.alternatives<"bifurcation"> selected_region = %c0_and_c1 : !transform.any_param {
|
||||
transform.yield
|
||||
}, {
|
||||
transform.yield
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @f()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%c2 = transform.param.constant 2 -> !transform.any_param
|
||||
// expected-error@below {{'selected_region' attribute/param specifies region at index 2 while op has only 2 regions}}
|
||||
transform.tune.alternatives<"bifurcation"> selected_region = %c2 : !transform.any_param {
|
||||
transform.yield
|
||||
}, {
|
||||
transform.yield
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func private @f()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
// expected-error@below {{non-deterministic choice "bifurcation" is only resolved through providing a `selected_region` attr/param}}
|
||||
transform.tune.alternatives<"bifurcation"> {
|
||||
transform.yield
|
||||
}, {
|
||||
transform.yield
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,3 +59,129 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: schedule_with_two_independent_choices_already_made
|
||||
func.func @schedule_with_two_independent_choices_already_made(
|
||||
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32> {
|
||||
// CHECK-NOT: scf.forall
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: scf.forall
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: scf.forall.in_parallel
|
||||
// CHECK: tensor.parallel_insert_slice
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK: scf.yield
|
||||
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
|
||||
outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
|
||||
return %0 : tensor<128x128xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
|
||||
%tiled_matmul = transform.tune.alternatives<"outer_par_or_seq_tiling"> selected_region = 0 -> !transform.any_op
|
||||
{ // First alternative/region, with index = 0
|
||||
%contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield %contained_matmul : !transform.any_op
|
||||
}, { // Second alternative/region, with index = 1
|
||||
%contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield %contained_matmul : !transform.any_op
|
||||
}
|
||||
|
||||
transform.tune.alternatives<"inner_par_or_seq_tiling"> selected_region = 1 -> !transform.any_op {
|
||||
%contained_matmul, %loop = transform.structured.tile_using_for %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield %contained_matmul : !transform.any_op
|
||||
}, {
|
||||
%contained_matmul, %loop = transform.structured.tile_using_forall %tiled_matmul tile_sizes [0, 16] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield %contained_matmul : !transform.any_op
|
||||
}
|
||||
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: subschedule_with_choice_resolved_in_main_schedule
|
||||
func.func @subschedule_with_choice_resolved_in_main_schedule(
|
||||
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
|
||||
-> tensor<128x128xf32> {
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: scf.forall
|
||||
// CHECK-NOT: scf.forall
|
||||
// CHECK: scf.for
|
||||
// CHECK-NOT: scf.forall
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK: tensor.extract_slice
|
||||
// CHECK: linalg.matmul
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK: scf.yield
|
||||
// CHECK: scf.forall.in_parallel
|
||||
// CHECK: tensor.parallel_insert_slice
|
||||
%0 = linalg.matmul ins(%arg0, %arg1: tensor<128x128xf32>, tensor<128x128xf32>)
|
||||
outs(%arg2: tensor<128x128xf32>) -> tensor<128x128xf32>
|
||||
return %0 : tensor<128x128xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @subschedule_with_embedded_choice(%matmul: !transform.any_op {transform.readonly},
|
||||
%par_or_seq: !transform.param<i64> {transform.readonly},
|
||||
%tile_size: !transform.param<i64> {transform.readonly}) -> !transform.any_op {
|
||||
%tiled_matmul = transform.tune.alternatives<"par_or_seq_tiling"> selected_region = %par_or_seq : !transform.param<i64> -> !transform.any_op {
|
||||
%contained_matmul, %loop = transform.structured.tile_using_for %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield %contained_matmul : !transform.any_op
|
||||
}, {
|
||||
%contained_matmul, %loop = transform.structured.tile_using_forall %matmul tile_sizes [%tile_size] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield %contained_matmul : !transform.any_op
|
||||
}
|
||||
transform.yield %tiled_matmul : !transform.any_op
|
||||
}
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%outer_par = transform.param.constant 1 -> !transform.param<i64>
|
||||
%outer_tile_size = transform.param.constant 32 -> !transform.param<i64>
|
||||
%inner_seq = transform.tune.knob<"inner_par_or_seq"> = 0 from options = [0, 1] -> !transform.param<i64>
|
||||
%inner_tile_size = transform.param.constant 8 -> !transform.param<i64>
|
||||
%tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%matmul, %outer_par, %outer_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
|
||||
%tiled_tiled_matmul = transform.include @subschedule_with_embedded_choice failures(propagate) (%tiled_matmul, %inner_seq, %inner_tile_size) : (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: eeny_meeny_miny_moe
|
||||
func.func private @eeny_meeny_miny_moe()
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
|
||||
%tiled_matmul = transform.tune.alternatives<"4way"> selected_region = 3 -> !transform.any_param
|
||||
{ // First alternative/region, with index = 0
|
||||
%out = transform.param.constant "eeny" -> !transform.any_param
|
||||
transform.yield %out : !transform.any_param
|
||||
}, { // Second alternative/region, with index = 1
|
||||
%out = transform.param.constant "meeny" -> !transform.any_param
|
||||
transform.yield %out : !transform.any_param
|
||||
}, { // Third alternative/region, with index = 2
|
||||
%out = transform.param.constant "miny" -> !transform.any_param
|
||||
transform.yield %out : !transform.any_param
|
||||
}, { // Fourth alternative/region, with index = 3
|
||||
%out = transform.param.constant "moe" -> !transform.any_param
|
||||
transform.yield %out : !transform.any_param
|
||||
}
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
@ -1,21 +1,21 @@
|
||||
# RUN: %PYTHON %s | FileCheck %s
|
||||
|
||||
from mlir.ir import *
|
||||
from mlir import ir
|
||||
from mlir.dialects import transform
|
||||
from mlir.dialects.transform import tune, debug
|
||||
|
||||
|
||||
def run(f):
|
||||
print("\nTEST:", f.__name__)
|
||||
with Context(), Location.unknown():
|
||||
module = Module.create()
|
||||
with InsertionPoint(module.body):
|
||||
print("\n// TEST:", f.__name__)
|
||||
with ir.Context(), ir.Location.unknown():
|
||||
module = ir.Module.create()
|
||||
with ir.InsertionPoint(module.body):
|
||||
sequence = transform.SequenceOp(
|
||||
transform.FailurePropagationMode.Propagate,
|
||||
[],
|
||||
transform.AnyOpType.get(),
|
||||
)
|
||||
with InsertionPoint(sequence.body):
|
||||
with ir.InsertionPoint(sequence.body):
|
||||
f(sequence.bodyTarget)
|
||||
transform.YieldOp()
|
||||
print(module)
|
||||
@ -29,10 +29,10 @@ def testKnobOp(target):
|
||||
|
||||
# CHECK: %[[HEADS_OR_TAILS:.*]] = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
|
||||
heads_or_tails = tune.KnobOp(
|
||||
result=any_param, name=StringAttr.get("coin"), options=[True, False]
|
||||
result=any_param, name=ir.StringAttr.get("coin"), options=[True, False]
|
||||
)
|
||||
# CHECK: transform.tune.knob<"animal"> options = ["cat", "dog", unit] -> !transform.any_param
|
||||
tune.KnobOp(any_param, name="animal", options=["cat", "dog", UnitAttr.get()])
|
||||
tune.KnobOp(any_param, name="animal", options=["cat", "dog", ir.UnitAttr.get()])
|
||||
# CHECK: transform.tune.knob<"tile_size"> options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
|
||||
tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32])
|
||||
# CHECK: transform.tune.knob<"magic_value"> options = [2.000000e+00, 2.250000e+00, 2.500000e+00, 2.750000e+00, 3.000000e+00] -> !transform.any_param
|
||||
@ -45,7 +45,10 @@ def testKnobOp(target):
|
||||
heads = tune.KnobOp(any_param, "coin", options=[True, False], selected=True)
|
||||
# CHECK: transform.tune.knob<"animal"> = "dog" from options = ["cat", "dog", unit] -> !transform.any_param
|
||||
tune.KnobOp(
|
||||
any_param, name="animal", options=["cat", "dog", UnitAttr.get()], selected="dog"
|
||||
any_param,
|
||||
name="animal",
|
||||
options=["cat", "dog", ir.UnitAttr.get()],
|
||||
selected="dog",
|
||||
)
|
||||
# CHECK: transform.tune.knob<"tile_size"> = 8 : i64 from options = [2, 4, 8, 16, 24, 32] -> !transform.any_param
|
||||
tune.KnobOp(any_param, "tile_size", [2, 4, 8, 16, 24, 32], selected=8)
|
||||
@ -57,16 +60,90 @@ def testKnobOp(target):
|
||||
|
||||
# CHECK: transform.tune.knob<"range_as_a_dict"> = 4 : i64 from options = {start = 2 : i64, step = 2 : i64, stop = 16 : i64} -> !transform.any_param
|
||||
# NB: Membership of `selected` in non-ArrayAttr `options` is _not_ verified.
|
||||
i64 = IntegerType.get_signless(64)
|
||||
i64 = ir.IntegerType.get_signless(64)
|
||||
tune.knob(
|
||||
any_param,
|
||||
"range_as_a_dict",
|
||||
DictAttr.get(
|
||||
ir.DictAttr.get(
|
||||
{
|
||||
"start": IntegerAttr.get(i64, 2),
|
||||
"stop": IntegerAttr.get(i64, 16),
|
||||
"step": IntegerAttr.get(i64, 2),
|
||||
"start": ir.IntegerAttr.get(i64, 2),
|
||||
"stop": ir.IntegerAttr.get(i64, 16),
|
||||
"step": ir.IntegerAttr.get(i64, 2),
|
||||
}
|
||||
),
|
||||
selected=4,
|
||||
)
|
||||
|
||||
|
||||
# CHECK-LABEL: TEST: testAlternativesOp
|
||||
@run
def testAlternativesOp(target):
    """Builds three `tune.alternatives` ops and checks their printed form.

    Covers an unresolved choice, a choice resolved by an integer, and a choice
    resolved by a param produced by an earlier op.
    """
    any_param = transform.AnyParamType.get()
    LEFT, RIGHT = 0, 1

    def i32_attr(value):
        # Signless 32-bit integer attribute holding `value`.
        return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), value)

    # --- Unresolved choice: no `selected_region` given. ---
    # CHECK: %[[LEFT_OR_RIGHT_OUTCOME:.*]] = transform.tune.alternatives<"left_or_right"> -> !transform.any_param {
    left_or_right = tune.AlternativesOp([any_param], "left_or_right", 2)
    with ir.InsertionPoint(left_or_right.alternatives[LEFT].blocks[0]):
        # CHECK: %[[C0:.*]] = transform.param.constant 0
        zero = transform.ParamConstantOp(any_param, i32_attr(0))
        # CHECK: transform.yield %[[C0]]
        transform.yield_(zero)
    # CHECK-NEXT: }, {
    with ir.InsertionPoint(left_or_right.alternatives[RIGHT].blocks[0]):
        # CHECK: %[[C1:.*]] = transform.param.constant 1
        one = transform.ParamConstantOp(any_param, i32_attr(1))
        # CHECK: transform.yield %[[C1]]
        transform.yield_(one)
    # CHECK-NEXT: }
    left_or_right_outcome = left_or_right.results[0]

    # --- Choice resolved by an integer. ---
    # CHECK: transform.tune.alternatives<"fork_in_the_road"> selected_region = 0 -> !transform.any_param {
    fork_in_the_road = tune.AlternativesOp(
        [any_param], "fork_in_the_road", 2, selected_region=0
    )
    with ir.InsertionPoint(fork_in_the_road.alternatives[LEFT].blocks[0]):
        # CHECK: %[[C0:.*]] = transform.param.constant 0
        zero = transform.ParamConstantOp(any_param, i32_attr(0))
        # CHECK: transform.yield %[[C0]]
        transform.yield_(zero)
    # CHECK-NEXT: }, {
    with ir.InsertionPoint(fork_in_the_road.alternatives[RIGHT].blocks[0]):
        # CHECK: %[[C1:.*]] = transform.param.constant 1
        one = transform.ParamConstantOp(any_param, i32_attr(1))
        # CHECK: transform.yield %[[C1]]
        transform.yield_(one)
    # CHECK-NEXT: }

    # --- Choice resolved by the param produced by the first op. ---
    # CHECK: transform.tune.alternatives<"left_or_right_as_before"> selected_region = %[[LEFT_OR_RIGHT_OUTCOME]] : !transform.any_param {
    left_or_right_as_before = tune.AlternativesOp(
        [],
        "left_or_right_as_before",
        2,
        selected_region=left_or_right_outcome,
    )
    with ir.InsertionPoint(left_or_right_as_before.alternatives[LEFT].blocks[0]):
        # CHECK: transform.param.constant 1337
        leet = transform.ParamConstantOp(any_param, i32_attr(1337))
        # CHECK: transform.debug.emit_param_as_remark
        debug.emit_param_as_remark(leet)
        transform.yield_([])
    # CHECK-NEXT: }, {
    with ir.InsertionPoint(left_or_right_as_before.alternatives[RIGHT].blocks[0]):
        # CHECK: transform.param.constant 42
        answer = transform.ParamConstantOp(any_param, i32_attr(42))
        # CHECK: transform.debug.emit_param_as_remark
        debug.emit_param_as_remark(answer)
        transform.yield_([])
    # CHECK-NEXT: }
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user