llvm-project/mlir/test/python/dialects/transform_vector_ext.py
Andrzej Warzyński 3a1111d668
[mlir][vector] Refine Vector to LLVM lowering options (#159553)
This is a follow-up to https://github.com/llvm/llvm-project/pull/144307,
where we removed `vector.matrix_multiply` and `vector.flat_transpose`
from the Vector dialect.

This PR:
* Updates comments that were missed in the previous change.
* Renames relevant `-convert-vector-to-llvm=` options:
  - `vector-contract-lowering=matmul` → `vector-contract-lowering=llvmintr`
  - `vector-transpose-lowering=flat_transpose` → `vector-transpose-lowering=llvmintr`

These new names better reflect the actual transformation target - LLVM
intrinsics - rather than the now-removed abstract operations.
2025-09-23 16:29:47 +01:00

152 lines
6.4 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import vector
def run_apply_patterns(f):
    """Run ``f`` inside a fresh ``transform.apply_patterns`` region and print it.

    A new module is created holding a ``transform.sequence`` whose body
    contains one ``transform.apply_patterns`` op followed by a yield. ``f`` is
    invoked with the insertion point set to the patterns region, so every op
    it constructs lands there. A ``TEST: <name>`` banner and the module are
    then printed to stdout for FileCheck.

    Returns ``f`` unchanged, so this can be used as a decorator.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                any_op_type = transform.AnyOpType.get()
                failure_mode = transform.FailurePropagationMode.Propagate
                seq_op = transform.SequenceOp(failure_mode, [], any_op_type)
                with InsertionPoint(seq_op.body):
                    apply_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                    with InsertionPoint(apply_op.patterns):
                        f()
                    # Built after the patterns region is filled, so it ends
                    # up last in the sequence body.
                    transform.YieldOp()
            print("\nTEST:", f.__name__)
            print(mod)
    return f
@run_apply_patterns
def non_configurable_patterns():
    """Build each vector pattern op that is constructed without arguments.

    The op classes are listed in the order in which they are created; the
    FileCheck directive above each entry is the line expected for that op.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns
    option_free_pattern_ops = (
        # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
        vector.ApplyCastAwayVectorLeadingOneDimPatternsOp,
        # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
        vector.ApplyRankReducingSubviewPatternsOp,
        # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
        vector.ApplyTransferPermutationPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_broadcast
        vector.ApplyLowerBroadcastPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masks
        vector.ApplyLowerMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_masked_transfers
        vector.ApplyLowerMaskedTransfersPatternsOp,
        # CHECK: transform.apply_patterns.vector.materialize_masks
        vector.ApplyMaterializeMasksPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_outerproduct
        vector.ApplyLowerOuterProductPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_gather
        vector.ApplyLowerGatherPatternsOp,
        # CHECK: transform.apply_patterns.vector.unroll_from_elements
        vector.ApplyUnrollFromElementsPatternsOp,
        # CHECK: transform.apply_patterns.vector.unroll_to_elements
        vector.ApplyUnrollToElementsPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_scan
        vector.ApplyLowerScanPatternsOp,
        # CHECK: transform.apply_patterns.vector.lower_shape_cast
        vector.ApplyLowerShapeCastPatternsOp,
    )
    for build_pattern_op in option_free_pattern_ops:
        build_pattern_op()
@run_apply_patterns
def configurable_patterns():
    """Build pattern ops that are configured with integer and boolean options."""
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    lower_transfer_options = {"max_transfer_rank": 4}
    vector.ApplyLowerTransferPatternsOp(**lower_transfer_options)
    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    transfer_to_scf_options = {"max_transfer_rank": 3, "full_unroll": True}
    vector.ApplyTransferToScfPatternsOp(**transfer_to_scf_options)
@run_apply_patterns
def enum_configurable_patterns():
    """Build pattern ops whose strategy is selected through an enum value.

    Each op is first created with no arguments and then once per enum value
    exercised here. One transpose variant additionally sets the boolean
    ``avx2_lowering_strategy`` flag.
    """
    # Anchor the directives below to this test's own output section, as the
    # sibling tests in this file do.
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = llvmintr
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.LLVMIntr
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = llvmintr
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.LLVMIntr
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = llvmintr
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.LLVMIntr,
        avx2_lowering_strategy=True,
    )

    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )