llvm-project/mlir/test/python/dialects/transform_vector_ext.py
Erick Ochoa Lopez 9d19250610
[mlir][vector] Add vector.to_elements unrolling (#157142)
This PR adds support for unrolling the source operand of `vector.to_elements`.

It transforms

```mlir
%0:8 = vector.to_elements %v : vector<2x2x2xf32>
```

to

```mlir
%v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32>
%v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32>
%0:4 = vector.to_elements %v0 : vector<2x2xf32>
%1:4 = vector.to_elements %v1 : vector<2x2xf32>
// The original %0:8 results are replaced by %0:4 followed by %1:4 (concatenated in order).
```

This pattern will be applied until there are only 1-D vectors left.

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
Co-authored-by: hanhanW <hanhan0912@gmail.com>
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
2025-09-11 13:56:57 -04:00

152 lines
6.4 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import vector
def run_apply_patterns(f):
    """Run ``f`` inside a fresh apply-patterns region and print the module.

    Intended for use as a decorator: decorating a function executes it right
    away. A new module is created that holds a transform sequence; the
    sequence body contains an ``ApplyPatternsOp`` on the sequence's body
    target followed by a ``YieldOp``. ``f`` is called while the insertion
    point is set to the ``patterns`` region of that op. Afterwards the test
    name and the module are printed to stdout.

    Args:
        f: Zero-argument callable that builds the pattern ops under test.

    Returns:
        ``f`` itself, unchanged.
    """
    with Context(), Location.unknown():
        test_module = Module.create()

        # Top-level transform sequence that hosts the pattern application.
        with InsertionPoint(test_module.body):
            seq_op = transform.SequenceOp(
                transform.FailurePropagationMode.Propagate,
                [],
                transform.AnyOpType.get(),
            )

        # Fill the sequence body: the pattern-application op, then the yield.
        with InsertionPoint(seq_op.body):
            patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
            with InsertionPoint(patterns_op.patterns):
                f()
            transform.YieldOp()

        # The printed name is what the label directives in each test match.
        print("\nTEST:", f.__name__)
        print(test_module)
    return f
@run_apply_patterns
def non_configurable_patterns():
    """Build the vector pattern ops that are constructed without arguments."""
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    pattern_op_classes = (
        # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
        vector.ApplyCastAwayVectorLeadingOneDimPatternsOp,
        # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
        vector.ApplyRankReducingSubviewPatternsOp,
        # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
        vector.ApplyTransferPermutationPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_broadcast
        vector.ApplyLowerBroadcastPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masks
        vector.ApplyLowerMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masked_transfers
        vector.ApplyLowerMaskedTransfersPatternsOp,
        # CHECK: transform.apply_patterns.vector.materialize_masks
        vector.ApplyMaterializeMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_outerproduct
        vector.ApplyLowerOuterProductPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_gather
        vector.ApplyLowerGatherPatternsOp,
        # CHECK: transform.apply_patterns.vector.unroll_from_elements
        vector.ApplyUnrollFromElementsPatternsOp,
        # CHECK: transform.apply_patterns.vector.unroll_to_elements
        vector.ApplyUnrollToElementsPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_scan
        vector.ApplyLowerScanPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_shape_cast
        vector.ApplyLowerShapeCastPatternsOp,
    )
    # Instantiate in listing order so the printed ops follow the same order
    # as the expectations written above each entry.
    for make_pattern_op in pattern_op_classes:
        make_pattern_op()
@run_apply_patterns
def configurable_patterns():
    """Build pattern ops whose integer and boolean options are set explicitly."""
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns

    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    vector.ApplyLowerTransferPatternsOp(
        max_transfer_rank=4,
    )

    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    vector.ApplyTransferToScfPatternsOp(
        max_transfer_rank=3,
        full_unroll=True,
    )
@run_apply_patterns
def enum_configurable_patterns():
    """Build pattern ops whose strategy is selected through an enum option.

    Each op kind is created once with no arguments and then once per explicit
    enum value passed below, so the printed form of every variant is covered.
    """
    # Scope the expectations to this test's own output section, in the same
    # way the sibling tests in this file do.
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns

    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.Matmul
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat,
        avx2_lowering_strategy=True,
    )

    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )