This patch introduces a new unrolling-based approach for lowering multi-dimensional `vector.from_elements` operations. **Implementation Details:** 1. **New Transform Pattern**: Added `UnrollFromElements`, which unrolls an N-D (N >= 2) from_elements op into (N-1)-D from_elements ops along the outermost dimension. 2. **Utility Functions**: Added `unrollVectorOp` to reuse the unrolling algorithm of vector.gather for vector.from_elements. 3. **Integration**: Added the unrolling pattern to the convert-vector-to-llvm pass as a temporary transformation. 4. In `VectorFromElementsLowering`, use direct LLVM dialect operations instead of intermediate vector.insert operations for efficiency. **Example:** ```mlir // unroll %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32> => %poison_2d = ub.poison : vector<2x2xf32> %vec_1d_0 = vector.from_elements %e0, %e1 : vector<2xf32> %vec_2d_0 = vector.insert %vec_1d_0, %poison_2d [0] : vector<2xf32> into vector<2x2xf32> %vec_1d_1 = vector.from_elements %e2, %e3 : vector<2xf32> %result = vector.insert %vec_1d_1, %vec_2d_0 [1] : vector<2xf32> into vector<2x2xf32> // convert-vector-to-llvm %v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32> => %poison_2d = ub.poison : vector<2x2xf32> %poison_2d_cast = builtin.unrealized_conversion_cast %poison_2d : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> %poison_1d_0 = llvm.mlir.poison : vector<2xf32> %c0_0 = llvm.mlir.constant(0 : i64) : i64 %vec_1d_0_0 = llvm.insertelement %e0, %poison_1d_0[%c0_0 : i64] : vector<2xf32> %c1_0 = llvm.mlir.constant(1 : i64) : i64 %vec_1d_0_1 = llvm.insertelement %e1, %vec_1d_0_0[%c1_0 : i64] : vector<2xf32> %vec_2d_0 = llvm.insertvalue %vec_1d_0_1, %poison_2d_cast[0] : !llvm.array<2 x vector<2xf32>> %poison_1d_1 = llvm.mlir.poison : vector<2xf32> %c0_1 = llvm.mlir.constant(0 : i64) : i64 %vec_1d_1_0 = llvm.insertelement %e2, %poison_1d_1[%c0_1 : i64] : vector<2xf32> %c1_1 = llvm.mlir.constant(1 : i64) : i64 %vec_1d_1_1 = llvm.insertelement %e3, 
%vec_1d_1_0[%c1_1 : i64] : vector<2xf32> %vec_2d_1 = llvm.insertvalue %vec_1d_1_1, %vec_2d_0[1] : !llvm.array<2 x vector<2xf32>> %result = builtin.unrealized_conversion_cast %vec_2d_1 : !llvm.array<2 x vector<2xf32>> to vector<2x2xf32> ``` --------- Co-authored-by: Nicolas Vasilache <Nico.Vasilache@amd.com> Co-authored-by: Yang Bai <yangb@nvidia.com> Co-authored-by: James Newling <james.newling@gmail.com> Co-authored-by: Diego Caballero <dieg0ca6aller0@gmail.com>
150 lines
6.3 KiB
Python
150 lines
6.3 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir.ir import *
|
|
from mlir.dialects import transform
|
|
from mlir.dialects.transform import vector
|
|
|
|
|
|
def run_apply_patterns(f):
    """Decorator: run ``f`` inside a fresh apply-patterns region and print it.

    A new module is created under a fresh ``Context`` and an unknown
    ``Location``. Inside it, a ``transform.SequenceOp`` is built, and inside
    that an ``transform.ApplyPatternsOp`` targeting the sequence's
    ``bodyTarget``. ``f`` is called with the insertion point set to the
    apply op's ``patterns`` region, so any ops it constructs land there.

    Afterwards a banner with ``f.__name__`` and the textual module are printed
    to stdout, which is what the FileCheck directives in the decorated
    functions are matched against.

    Returns ``f`` itself (unwrapped), so the decorated name still refers to
    the original function; the work happens once, at decoration time.
    """
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            failure_mode = transform.FailurePropagationMode.Propagate
            # Second argument is an empty list -- presumably the sequence's
            # result types; confirm against the SequenceOp builder.
            seq_op = transform.SequenceOp(
                failure_mode, [], transform.AnyOpType.get()
            )
            with InsertionPoint(seq_op.body):
                patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                with InsertionPoint(patterns_op.patterns):
                    f()
                # Emitted in the sequence body, after the apply-patterns op.
                transform.YieldOp()
        print("\nTEST:", f.__name__)
        print(module)
    return f
|
|
|
|
|
|
@run_apply_patterns
def non_configurable_patterns():
    """Emit each vector pattern op that is constructed without arguments.

    Every constructor call is preceded by the FileCheck directive that must
    match the op's printed form; directives are matched in file order, so the
    calls below must stay in the same order as their directives.
    """
    # CHECK-LABEL: TEST: non_configurable_patterns
    # CHECK: apply_patterns

    # CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
    vector.ApplyCastAwayVectorLeadingOneDimPatternsOp()

    # CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
    vector.ApplyRankReducingSubviewPatternsOp()

    # CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
    vector.ApplyTransferPermutationPatternsOp()

    # CHECK: transform.apply_patterns.vector.lower_broadcast
    vector.ApplyLowerBroadcastPatternsOp()

    # CHECK: transform.apply_patterns.vector.lower_masks
    vector.ApplyLowerMasksPatternsOp()

    # CHECK: transform.apply_patterns.vector.lower_masked_transfers
    vector.ApplyLowerMaskedTransfersPatternsOp()

    # CHECK: transform.apply_patterns.vector.materialize_masks
    vector.ApplyMaterializeMasksPatternsOp()

    # CHECK: transform.apply_patterns.vector.lower_outerproduct
    vector.ApplyLowerOuterProductPatternsOp()

    # CHECK: transform.apply_patterns.vector.lower_gather
    vector.ApplyLowerGatherPatternsOp()

    # CHECK: transform.apply_patterns.vector.unroll_from_elements
    vector.ApplyUnrollFromElementsPatternsOp()

    # CHECK: transform.apply_patterns.vector.lower_scan
    vector.ApplyLowerScanPatternsOp()

    # CHECK: transform.apply_patterns.vector.lower_shape_cast
    vector.ApplyLowerShapeCastPatternsOp()
|
|
|
|
|
|
@run_apply_patterns
def configurable_patterns():
    """Emit vector pattern ops that take integer / boolean keyword options.

    The FileCheck directives before each call require the given option values
    to appear on the same printed line as the op name.
    """
    # CHECK-LABEL: TEST: configurable_patterns
    # CHECK: apply_patterns

    # CHECK: transform.apply_patterns.vector.lower_transfer
    # CHECK-SAME: max_transfer_rank = 4
    vector.ApplyLowerTransferPatternsOp(
        max_transfer_rank=4,
    )

    # CHECK: transform.apply_patterns.vector.transfer_to_scf
    # CHECK-SAME: max_transfer_rank = 3
    # CHECK-SAME: full_unroll = true
    vector.ApplyTransferToScfPatternsOp(
        max_transfer_rank=3,
        full_unroll=True,
    )
|
|
|
|
|
|
@run_apply_patterns
def enum_configurable_patterns():
    """Emit vector pattern ops whose options are enum-valued.

    For each op, the no-argument form is built first, followed by one call per
    enum value exercised here. The directives before each call state what must
    appear in the printed op; where a comment says a value is the default, no
    same-line directive is given because that value is not printed.

    The label directive below anchors this function's checks to its own
    section of the output, matching the other tests in this file.
    """
    # CHECK-LABEL: TEST: enum_configurable_patterns
    # CHECK: apply_patterns
    # CHECK: transform.apply_patterns.vector.lower_contraction
    vector.ApplyLowerContractionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = matmulintrinsics
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.Matmul
    )
    # CHECK: transform.apply_patterns.vector.lower_contraction
    # CHECK-SAME: lowering_strategy = parallelarith
    vector.ApplyLowerContractionPatternsOp(
        lowering_strategy=vector.VectorContractLowering.ParallelArith
    )

    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    vector.ApplyLowerMultiReductionPatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # This is the default mode, not printed.
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel
    )
    # CHECK: transform.apply_patterns.vector.lower_multi_reduction
    # CHECK-SAME: lowering_strategy = innerreduction
    vector.ApplyLowerMultiReductionPatternsOp(
        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
    )

    # CHECK: transform.apply_patterns.vector.lower_transpose
    vector.ApplyLowerTransposePatternsOp()
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # This is the default strategy, not printed.
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.EltWise
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_1d
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle1D
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = shuffle_16x16
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Shuffle16x16
    )
    # CHECK: transform.apply_patterns.vector.lower_transpose
    # CHECK-SAME: lowering_strategy = flat_transpose
    # CHECK-SAME: avx2_lowering_strategy = true
    vector.ApplyLowerTransposePatternsOp(
        lowering_strategy=vector.VectorTransposeLowering.Flat,
        avx2_lowering_strategy=True,
    )

    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    vector.ApplySplitTransferFullPartialPatternsOp()
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = none
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.None_
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "vector-transfer"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.VectorTransfer
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # This is the default mode, not printed.
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.LinalgCopy
    )
    # CHECK: transform.apply_patterns.vector.split_transfer_full_partial
    # CHECK-SAME: split_transfer_strategy = "force-in-bounds"
    vector.ApplySplitTransferFullPartialPatternsOp(
        split_transfer_strategy=vector.VectorTransferSplit.ForceInBounds
    )
|