llvm-project/mlir/test/python/dialects/transform_structured_ext.py
Ingo Müller 9442b441c1 [mlir][linalg][transform][python] Fix optional args of PadOp mix-in.
The mix-in did not allow many of the arguments to be left *unset*, even
though they represent optional attributes. Instead, it set default values,
which have different semantics in some cases. In other cases, setting
the default values is already done by the C++ layer, in which case they
are currently redundant and may be wrong in some potential future change
in the TD or C++ files. With this patch, `None` is preserved until the
generated binding, which handles them as desired.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158844
2023-09-02 11:19:06 +00:00

675 lines
24 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects import pdl
from mlir.dialects.transform import structured
from mlir.dialects.transform import pdl as transform_pdl
def run(f):
    """Run test function `f` inside a fresh module and print the result.

    A context and an unknown location are entered, an empty module is created,
    a banner carrying the function name is printed, and `f` is invoked with the
    insertion point set to the module body. The module is then verified and
    printed. `f` is returned so that `run` can serve as a decorator.
    """
    with Context(), Location.unknown():
        test_module = Module.create()
        body_ip = InsertionPoint(test_module.body)
        with body_ip:
            print("\nTEST:", f.__name__)
            f()
        test_module.operation.verify()
        print(test_module)
    return f
@run
def testBufferizeToAllocationOpCompact():
    """Build `BufferizeToAllocationOp` from the sequence target alone."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        structured.BufferizeToAllocationOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpCompact
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation
@run
def testBufferizeToAllocationOpArgs():
    """Build `BufferizeToAllocationOp` with its keyword arguments set.

    The memory space, the memcpy/alloc op names and the destination-only flag
    are all passed explicitly and are expected to appear in the printed op.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    options = dict(
        memory_space=3,
        memcpy_op="memref.copy",
        alloc_op="memref.alloca",
        bufferize_destination_only=True,
    )
    with InsertionPoint(seq.body):
        structured.BufferizeToAllocationOp(seq.bodyTarget, **options)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testBufferizeToAllocationOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.bufferize_to_allocation
    # CHECK-SAME: alloc_op = "memref.alloca"
    # CHECK-SAME: bufferize_destination_only
    # CHECK-SAME: memcpy_op = "memref.copy"
    # CHECK-SAME: memory_space = 3
@run
def testDecompose():
    """Build a `DecomposeOp` applied to the sequence target."""
    root_type = pdl.OperationType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], root_type
    )
    with InsertionPoint(seq.body):
        structured.DecomposeOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testDecompose
    # CHECK: transform.sequence
    # CHECK: transform.structured.decompose
@run
def testFuseIntoContainingOpTypes():
    """Build `FuseIntoContainingOp` with two explicit result types.

    Both operands come from `match_op_names` on the sequence target; the two
    leading type arguments are expected to show up as the op's result types.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        fused_handle = structured.MatchOp.match_op_names(target, ["test.dummy"])
        containing_handle = structured.MatchOp.match_op_names(
            target, ["test.dummy"]
        )
        first_result_type = transform.OperationType.get("test.dummy")
        second_result_type = transform.OperationType.get("test.dummy")
        structured.FuseIntoContainingOp(
            first_result_type,
            second_result_type,
            fused_handle,
            containing_handle,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testFuseIntoContainingOpTypes
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.op<"test.dummy">, !transform.op<"test.dummy">)
@run
def testFuseIntoContainingOpCompact():
    """Build `FuseIntoContainingOp` from two handles without explicit types."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        fused_handle = structured.MatchOp.match_op_names(target, ["test.dummy"])
        containing_handle = structured.MatchOp.match_op_names(
            target, ["test.dummy"]
        )
        structured.FuseIntoContainingOp(fused_handle, containing_handle)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testFuseIntoContainingOpCompact
    # CHECK: = transform.structured.fuse_into_containing_op
    # CHECK-SAME: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
@run
def testGeneralize():
    """Build a `GeneralizeOp` applied to the sequence target."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        structured.GeneralizeOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testGeneralize
    # CHECK: transform.sequence
    # CHECK: transform.structured.generalize
@run
def testInterchange():
    """Build an `InterchangeOp` with `iterator_interchange` given as a list."""
    root_type = pdl.OperationType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], root_type
    )
    permutation = [1, 0]
    with InsertionPoint(seq.body):
        structured.InterchangeOp(seq.bodyTarget, iterator_interchange=permutation)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testInterchange
    # CHECK: transform.sequence
    # CHECK: transform.structured.interchange
    # CHECK: iterator_interchange = [1, 0]
@run
def testMapCopyToThreadsOpCompact():
    """Build `MapCopyToThreadsOp` without explicit result types."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        structured.MapCopyToThreadsOp(
            seq.bodyTarget,
            total_num_threads=32,
            desired_bit_alignment=128,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpCompact
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
@run
def testMapCopyToThreadsOpTypes():
    """Build `MapCopyToThreadsOp` with two explicit result types.

    The two leading type arguments are expected to be printed as the result
    types of the op.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        type_a = transform.OperationType.get("test.opA")
        type_b = transform.OperationType.get("test.opB")
        structured.MapCopyToThreadsOp(
            type_a,
            type_b,
            seq.bodyTarget,
            total_num_threads=32,
            desired_bit_alignment=128,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMapCopyToThreadsOpTypes
    # CHECK: = transform.structured.gpu.map_copy_to_threads
    # CHECK-SAME: total_num_threads = 32
    # CHECK-SAME: desired_bit_alignment = 128
    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"test.opA">, !transform.op<"test.opB">)
@run
def testMatchOpNamesString():
    """Call `MatchOp.match_op_names` with a single op name given as a string."""
    any_op = transform.AnyOpType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], any_op
    )
    with InsertionPoint(seq.body):
        structured.MatchOp.match_op_names(seq.bodyTarget, "test.dummy")
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMatchOpNamesString
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
@run
def testMatchOpNamesList():
    """Call `MatchOp.match_op_names` with the op names given as a list."""
    any_op = transform.AnyOpType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], any_op
    )
    op_names = ["test.dummy"]
    with InsertionPoint(seq.body):
        structured.MatchOp.match_op_names(seq.bodyTarget, op_names)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMatchOpNamesList
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.any_op
@run
def testMaskedVectorizeStatic():
    """Build `MaskedVectorizeOp` with vector sizes given as Python integers."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    vector_sizes = [16, 4]
    with InsertionPoint(seq.body):
        structured.MaskedVectorizeOp(seq.bodyTarget, vector_sizes)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMaskedVectorizeStatic
    # CHECK: transform.sequence
    # CHECK: transform.structured.masked_vectorize
    # CHECK-SAME: vector_sizes [16, 4]
@run
def testMaskedVectorizeArray():
    """Build `MaskedVectorizeOp` with vector sizes given as a parsed attribute."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        sizes_attr = Attribute.parse("[16, 4]")
        structured.MaskedVectorizeOp(seq.bodyTarget, sizes_attr)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMaskedVectorizeArray
    # CHECK: transform.sequence
    # CHECK: transform.structured.masked_vectorize
    # CHECK-SAME: vector_sizes [16, 4]
@run
def testMaskedVectorizeMixed():
    """Build `MaskedVectorizeOp` with one handle size and one attribute size."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        dynamic_size = structured.MatchOp.match_op_names(target, ["arith.constant"])
        static_size = Attribute.parse("4")
        structured.MaskedVectorizeOp(target, [dynamic_size, static_size])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMaskedVectorizeMixed
    # CHECK: transform.sequence
    # CHECK: %[[V0:.*]] = transform.structured.match
    # CHECK: transform.structured.masked_vectorize
    # CHECK-SAME: vector_sizes [%[[V0]] : !transform.any_op, 4]
@run
def testMaskedVectorizeScalable():
    """Build `MaskedVectorizeOp` mixing plain sizes and single-element lists.

    Sizes wrapped in a one-element list (a handle, an attribute and an integer)
    are expected to be printed inside their own brackets.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        handle_size = structured.MatchOp.match_op_names(target, ["arith.constant"])
        attr_size = Attribute.parse("4")
        vector_sizes = [16, [handle_size], [attr_size], [8]]
        structured.MaskedVectorizeOp(target, vector_sizes)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMaskedVectorizeScalable
    # CHECK: transform.sequence
    # CHECK-DAG: %[[V0:.*]] = transform.structured.match
    # CHECK-DAG: transform.structured.masked_vectorize
    # CHECK-SAME: vector_sizes [16, [%[[V0]] : !transform.any_op], [4], [8]]
@run
def testMaskedVectorizeArgs():
    """Build `MaskedVectorizeOp` with `vectorize_nd_extract` enabled."""
    root_type = pdl.OperationType.get()
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], root_type
    )
    with InsertionPoint(seq.body):
        structured.MaskedVectorizeOp(
            seq.bodyTarget,
            [16, 4],
            vectorize_nd_extract=True,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMaskedVectorizeArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.masked_vectorize
    # CHECK-SAME: vectorize_nd_extract
@run
def testMatchOpNamesTyped():
    """Call `MatchOp.match_op_names` with an explicit result type first."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        result_type = transform.OperationType.get("test.dummy")
        structured.MatchOp.match_op_names(result_type, seq.bodyTarget, ["test.dummy"])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMatchOpNamesTyped
    # CHECK: transform.structured.match ops
    # CHECK-SAME: ["test.dummy"]
    # CHECK-SAME: (!transform.any_op) -> !transform.op<"test.dummy">
@run
def testMultitileSizes():
    """Build `MultiTileSizesOp` with a dimension and a target size."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        result_type = pdl.OperationType.get()
        structured.MultiTileSizesOp(
            result_type,
            seq.bodyTarget,
            dimension=1,
            target_size=42,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMultitileSizes
    # CHECK: transform.sequence
    # CHECK: transform.structured.multitile_sizes
    # CHECK-DAG: dimension = 1
    # CHECK-DAG: target_size = 42
@run
def testPadOpNoArgs():
    """Build `PadOp` from the target alone.

    None of the optional attributes listed in the negative checks below is
    expected to be printed.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        structured.PadOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testPadOpNoArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-NOT: copy_back_op
    # CHECK-NOT: pack_paddings
    # CHECK-NOT: pad_to_multiple_of
    # CHECK-NOT: padding_dimensions
    # CHECK-NOT: padding_values
    # CHECK-NOT: transpose_paddings
@run
def testPadOpArgs():
    """Build `PadOp` with every keyword argument used by this test set.

    The arguments deliberately mix Python values, parsed attributes and
    attribute objects. Two padding values are passed (an f32 float and a
    string attribute), so the expected `padding_values` array must list both.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], pdl.OperationType.get()
    )
    with InsertionPoint(sequence.body):
        structured.PadOp(
            sequence.bodyTarget,
            padding_values=[FloatAttr.get_f32(42.0), StringAttr.get("0")],
            padding_dimensions=Attribute.parse("[1]"),
            pad_to_multiple_of=[128],
            pack_paddings=[0],
            transpose_paddings=[[1, Attribute.parse("0")], Attribute.parse("[0, 1]")],
            copy_back_op="linalg.copy",
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testPadOpArgs
    # CHECK: transform.sequence
    # CHECK: transform.structured.pad
    # CHECK-DAG: copy_back_op = "linalg.copy"
    # CHECK-DAG: pack_paddings = [0]
    # CHECK-DAG: pad_to_multiple_of = [128]
    # CHECK-DAG: padding_dimensions = [1]
    # CHECK-DAG: padding_values = [4.200000e+01 : f32, "0"]
    # CHECK-DAG: transpose_paddings = {{\[}}[1, 0], [0, 1]]
@run
def testScalarize():
    """Build a `ScalarizeOp` applied to the sequence target."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        structured.ScalarizeOp(seq.bodyTarget)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testScalarize
    # CHECK: transform.structured.scalarize
@run
def testSplit():
    """Build two chained `SplitOp`s.

    The first uses an integer split point; the second takes the first op's
    result 0 as its target and result 1 as its split point.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    with InsertionPoint(seq.body):
        first_split = structured.SplitOp(seq.bodyTarget, dimension=1, split_point=42)
        first_part = first_split.results[0]
        second_part = first_split.results[1]
        structured.SplitOp(first_part, dimension=3, split_point=second_part)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testSplit
    # CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
    # CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
@run
def testTileCompact():
    """Build `TileOp` with sizes and interchange given as integer lists."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    tile_sizes = [4, 8]
    loop_order = [0, 1]
    with InsertionPoint(seq.body):
        structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileCompact
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]
@run
def testTileAttributes():
    """Build `TileOp` with sizes and interchange given as `DenseI64ArrayAttr`."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    sizes_attr = DenseI64ArrayAttr.get([4, 8])
    interchange_attr = DenseI64ArrayAttr.get([0, 1])
    with InsertionPoint(seq.body):
        structured.TileOp(
            seq.bodyTarget,
            sizes=sizes_attr,
            interchange=interchange_attr,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileAttributes
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 8]
    # CHECK: interchange = [0, 1]
@run
def testTileZero():
    """Build `TileOp` whose sizes contain zeros.

    With sizes `[4, 0, 2, 0]` the printed op is expected to have two loop
    results.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    tile_sizes = [4, 0, 2, 0]
    loop_order = [0, 1, 2, 3]
    with InsertionPoint(seq.body):
        structured.TileOp(seq.bodyTarget, sizes=tile_sizes, interchange=loop_order)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileZero
    # CHECK: transform.sequence
    # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.tile %{{.*}}[4, 0, 2, 0]
    # CHECK: interchange = [0, 1, 2, 3]
@run
def testTileDynamic():
    """Build `TileOp` whose sizes mix `PDLMatchOp` results and integers.

    The sequence is nested inside a `WithPDLPatternsOp` so that the two
    `pdl_match` ops can supply the dynamic sizes.
    """
    pdl_scope = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get())
    with InsertionPoint(pdl_scope.body):
        propagate = transform.FailurePropagationMode.Propagate
        seq = transform.SequenceOp(propagate, [], pdl_scope.bodyTarget)
        with InsertionPoint(seq.body):
            target = seq.bodyTarget
            first_match = transform_pdl.PDLMatchOp(
                pdl.OperationType.get(), target, "first"
            )
            second_match = transform_pdl.PDLMatchOp(
                pdl.OperationType.get(), target, "second"
            )
            tile_sizes = [first_match, 3, second_match, 0]
            structured.TileOp(target, sizes=tile_sizes)
            transform.YieldOp()
    # CHECK-LABEL: TEST: testTileDynamic
    # CHECK: %[[FIRST:.+]] = pdl_match
    # CHECK: %[[SECOND:.+]] = pdl_match
    # CHECK: %{{.+}}, %{{.+}}:3 = transform.structured.tile %{{.*}}[%[[FIRST]], 3, %[[SECOND]], 0]
@run
def testTileExplicitLoopTypeSingle():
    """Build `TileOp` with a single explicit loop type.

    With three tile sizes, the one `scf.for` type is expected to appear three
    times in the printed result types.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        loop_type = transform.OperationType.get("scf.for")
        structured.TileOp(loop_type, seq.bodyTarget, sizes=[2, 3, 4])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeSingle
    # CHECK: = transform.structured.tile %{{.*}} : (!{{.*}}) ->
    # CHECK-COUNT-3: !transform.op<"scf.for">
@run
def testTileExplicitLoopTypeAll():
    """Build `TileOp` with one explicit loop type per tile size.

    The three loop types are expected to follow the tiled-op result type in the
    printed function type. The first same-line directive below previously had a
    stray space before its colon, so FileCheck silently ignored it; it is now a
    real directive.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    types = [
        transform.OperationType.get(x)
        for x in ["scf.for", "scf.parallel", "scf.forall"]
    ]
    with InsertionPoint(sequence.body):
        structured.TileOp(types, sequence.bodyTarget, sizes=[2, 3, 4])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileExplicitLoopTypeAll
    # CHECK: = transform.structured.tile
    # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">,
    # CHECK-SAME: !transform.op<"scf.parallel">, !transform.op<"scf.forall">
@run
def testTileToForallCompact():
    """Build `TileToForallOp` with `num_threads` only.

    The sequence target is typed as `linalg.matmul`, which is expected to show
    up as the operand type of the printed op.
    """
    matmul_type = transform.OperationType.get("linalg.matmul")
    seq = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], matmul_type
    )
    with InsertionPoint(seq.body):
        structured.TileToForallOp(seq.bodyTarget, num_threads=[2, 3, 4])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileToForallCompact
    # CHECK: = transform.structured.tile_to_forall_op
    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
    # CHECK-SAME: (!transform.op<"linalg.matmul">) -> (!transform.any_op, !transform.any_op)
@run
def testTileToForallLoopsAndTileOpTypes():
    """Build `TileToForallOp` with explicit loops and tiled-op result types."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        loops_type = transform.OperationType.get("scf.forall")
        tiled_op_type = transform.OperationType.get("linalg.matmul")
        structured.TileToForallOp(
            loops_type,
            tiled_op_type,
            seq.bodyTarget,
            num_threads=[2, 3, 4],
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileToForallLoopsAndTileOpTypes
    # CHECK: = transform.structured.tile_to_forall_op
    # CHECK-SAME: num_threads [2, 3, 4] tile_sizes []
    # CHECK-SAME: (!transform.any_op) -> (!transform.op<"scf.forall">, !transform.op<"linalg.matmul">)
@run
def testTileToForallTileSizes():
    """Build `TileToForallOp` with `tile_sizes` instead of `num_threads`."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    sizes = [2, 3, 4]
    with InsertionPoint(seq.body):
        structured.TileToForallOp(seq.bodyTarget, tile_sizes=sizes)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileToForallTileSizes
    # CHECK: = transform.structured.tile_to_forall_op
    # CHECK-SAME: num_threads [] tile_sizes [2, 3, 4]
@run
def testTileToForallMixedDynamic():
    """Build `TileToForallOp` whose `num_threads` mixes a handle and integers."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        target = seq.bodyTarget
        thread_handle = structured.MatchOp.match_op_names(target, ["test.dummy"])
        structured.TileToForallOp(target, num_threads=[thread_handle, 3, 4])
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileToForallMixedDynamic
    # CHECK: = transform.structured.tile_to_forall_op
    # CHECK-SAME: num_threads [%{{.*}} : !transform.any_op, 3, 4]
@run
def testTileToForallPackedDynamic():
    """Build `TileToForallOp` with `num_threads` given as one packed handle.

    A single handle (rather than a list) is passed as `num_threads`; it is
    expected to be printed in the packed `*(...)` form. The SSA name of the
    handle is matched with a wildcard instead of a hard-coded value number so
    the check does not depend on how the printer numbers values.
    """
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        n = structured.MatchOp.match_op_names(sequence.bodyTarget, ["test.dummy"])
        structured.TileToForallOp(sequence.bodyTarget, num_threads=n)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileToForallPackedDynamic
    # CHECK: = transform.structured.tile_to_forall_op
    # CHECK-SAME: num_threads *(%{{.*}} : !transform.any_op)
@run
def testTileToForallMapping():
    """Build `TileToForallOp` with a `mapping` attribute parsed from text."""
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        thread_mapping = Attribute.parse("[ #gpu.thread<y>, #gpu.thread<x> ]")
        structured.TileToForallOp(
            seq.bodyTarget,
            num_threads=[2, 3],
            mapping=thread_mapping,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testTileToForallMapping
    # CHECK: = transform.structured.tile_to_forall_op
    # CHECK-SAME: mapping = [#gpu.thread<y>, #gpu.thread<x>]
@run
def testVectorizeAllAttrs():
    """Build `VectorizeOp` with all four boolean keyword arguments set to True.

    Each flag is expected to be printed on the op.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    flags = {
        "disable_multi_reduction_to_contract_patterns": True,
        "disable_transfer_permutation_map_lowering_patterns": True,
        "vectorize_nd_extract": True,
        "vectorize_padding": True,
    }
    with InsertionPoint(seq.body):
        structured.VectorizeOp(seq.bodyTarget, **flags)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testVectorizeAllAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize
    # CHECK-SAME: disable_multi_reduction_to_contract_patterns
    # CHECK-SAME: disable_transfer_permutation_map_lowering_patterns
    # CHECK-SAME: vectorize_nd_extract
    # CHECK-SAME: vectorize_padding
@run
def testVectorizeNoAttrs():
    """Build `VectorizeOp` with all four boolean keyword arguments set to False.

    None of the flags is expected to be printed on the op.
    """
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], pdl.OperationType.get())
    flags = {
        "disable_multi_reduction_to_contract_patterns": False,
        "disable_transfer_permutation_map_lowering_patterns": False,
        "vectorize_nd_extract": False,
        "vectorize_padding": False,
    }
    with InsertionPoint(seq.body):
        structured.VectorizeOp(seq.bodyTarget, **flags)
        transform.YieldOp()
    # CHECK-LABEL: TEST: testVectorizeNoAttrs
    # CHECK: transform.sequence
    # CHECK: = transform.structured.vectorize
    # CHECK-NOT: disable_multi_reduction_to_contract_patterns
    # CHECK-NOT: disable_transfer_permutation_map_lowering_patterns
    # CHECK-NOT: vectorize_nd_extract
    # CHECK-NOT: vectorize_padding
@run
def testMatchInterfaceEnum():
    """Build a match op with `interface` given as a `MatchInterfaceEnum` member.

    The op is constructed through the base class of `structured.MatchOp`
    (`__base__`), passing the result type, target, op names and interface
    directly.
    """
    op_names = ArrayAttr.get([StringAttr.get("test.dummy")])
    any_op = transform.AnyOpType.get()
    propagate = transform.FailurePropagationMode.Propagate
    seq = transform.SequenceOp(propagate, [], transform.AnyOpType.get())
    with InsertionPoint(seq.body):
        structured.MatchOp.__base__(
            any_op,
            seq.bodyTarget,
            ops=op_names,
            interface=structured.MatchInterfaceEnum.LinalgOp,
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMatchInterfaceEnum
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{LinalgOp}
@run
def testMatchInterfaceEnumReplaceAttributeBuilder():
    """Replace the `MatchInterfaceEnum` attribute builder and use it via a string.

    A custom builder is registered with `replace=True`; it turns the strings
    "LinalgOp" and "TilingInterface" into 32-bit signless integer attributes.
    The match op is then built with `interface="TilingInterface"` and is
    expected to print that interface.
    """

    @register_attribute_builder("MatchInterfaceEnum", replace=True)
    def match_interface_enum(x, context):
        # Translate the interface name into the integer this builder encodes.
        if x == "LinalgOp":
            y = 0
        elif x == "TilingInterface":
            y = 1
        else:
            # Fail with a clear message instead of an UnboundLocalError on `y`.
            raise ValueError(f"unsupported MatchInterfaceEnum value: {x!r}")
        return IntegerAttr.get(IntegerType.get_signless(32, context=context), y)

    names = ArrayAttr.get([StringAttr.get("test.dummy")])
    result_type = transform.AnyOpType.get()
    sequence = transform.SequenceOp(
        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
    )
    with InsertionPoint(sequence.body):
        structured.MatchOp.__base__(
            result_type,
            sequence.bodyTarget,
            ops=names,
            interface="TilingInterface",
        )
        transform.YieldOp()
    # CHECK-LABEL: TEST: testMatchInterfaceEnumReplaceAttributeBuilder
    # CHECK: transform.sequence
    # CHECK: = transform.structured.match
    # CHECK: interface{TilingInterface}