[mlir][Vector] Move mask materialization patterns to greedy rewrite (#119973)

The mask materialization patterns used during the `VectorToLLVM` lowering
are rewrite patterns. They should run as part of the greedy pattern rewrite
and not the dialect conversion. (Rewrite patterns and conversion patterns
are not generally compatible.)

The current combination of rewrite patterns and conversion patterns
triggered an edge case when merging the 1:1 and 1:N dialect conversions.
Matthias Springer 2024-12-17 11:26:31 +01:00 committed by GitHub
parent a1f5fe8c85
commit 8cd8b5079b
4 changed files with 44 additions and 51 deletions
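
In essence, the change moves one populate call from the dialect-conversion
stage of the pass into the preceding greedy rewrite stage. The following is a
condensed, hand-written sketch of `ConvertVectorToLLVMPass::runOnOperation()`
after this commit; several populate calls and argument lists are abbreviated,
and the diff below is authoritative.

void ConvertVectorToLLVMPass::runOnOperation() {
  // Stage 1: greedy pattern rewrite. Only rewrite patterns belong here.
  {
    RewritePatternSet patterns(&getContext());
    populateVectorToVectorCanonicalizationPatterns(patterns);
    // ... further progressive-lowering patterns elided ...
    populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
    // Mask materialization now happens as part of the greedy rewrite
    // instead of the dialect conversion below.
    populateVectorMaskMaterializationPatterns(patterns,
                                              force32BitVectorIndices);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Stage 2: dialect conversion to LLVM. Only conversion patterns belong
  // here; populateVectorMaskMaterializationPatterns is no longer called at
  // this point.
  LowerToLLVMOptions options(&getContext());
  LLVMTypeConverter converter(&getContext(), options);
  RewritePatternSet patterns(&getContext());
  populateVectorTransferLoweringPatterns(patterns);
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns /*, ...*/);
  // ... run the LLVM dialect conversion with these patterns ...
}

The test churn below appears to follow from this move: the materialized
comparisons are now also canonicalized by the greedy rewriter, which moves the
constant index vector to the right-hand side, turning
`arith.cmpi slt, %indices, %bound` into the equivalent
`arith.cmpi sgt, %bound, %indices`. The scalable cases, whose index vector
comes from `llvm.intr.stepvector` and is not a constant, keep `slt`.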

@@ -61,8 +61,8 @@ struct ConvertVectorToLLVMPass
} // namespace
void ConvertVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
// Perform progressive lowering of operations on slices and all contraction
// operations. Also materializes masks, applies folding and DCE.
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
@@ -76,6 +76,8 @@ void ConvertVectorToLLVMPass::runOnOperation() {
VectorTransformsOptions());
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
populateVectorMaskMaterializationPatterns(patterns,
force32BitVectorIndices);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
@@ -83,7 +85,6 @@ void ConvertVectorToLLVMPass::runOnOperation() {
LowerToLLVMOptions options(&getContext());
LLVMTypeConverter converter(&getContext(), options);
RewritePatternSet patterns(&getContext());
populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(

@@ -7,7 +7,7 @@
// CMP32: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i32
// CMP32: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi32>
// CMP32: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi32>
// CMP32: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi32>
// CMP32: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi32>
// CMP32: return %[[T4]] : vector<11xi1>
// CMP64-LABEL: @genbool_var_1d(
@@ -16,7 +16,7 @@
// CMP64: %[[T1:.*]] = arith.index_cast %[[ARG]] : index to i64
// CMP64: %[[T2:.*]] = llvm.insertelement %[[T1]], %{{.*}}[%{{.*}} : i32] : vector<11xi64>
// CMP64: %[[T3:.*]] = llvm.shufflevector %[[T2]], %{{.*}} [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] : vector<11xi64>
// CMP64: %[[T4:.*]] = arith.cmpi slt, %[[T0]], %[[T3]] : vector<11xi64>
// CMP64: %[[T4:.*]] = arith.cmpi sgt, %[[T3]], %[[T0]] : vector<11xi64>
// CMP64: return %[[T4]] : vector<11xi1>
func.func @genbool_var_1d(%arg0: index) -> vector<11xi1> {

@@ -3097,7 +3097,7 @@ func.func @create_mask_0d(%num_elems : index) -> vector<i1> {
// CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
// CHECK: %[[BOUNDS:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
// CHECK: %[[BOUNDS_CAST:.*]] = builtin.unrealized_conversion_cast %[[BOUNDS]] : vector<1xi32> to vector<i32>
// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS_CAST]] : vector<i32>
// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS_CAST]], %[[INDICES]] : vector<i32>
// CHECK: return %[[RESULT]] : vector<i1>
// -----
@@ -3113,7 +3113,7 @@ func.func @create_mask_1d(%num_elems : index) -> vector<4xi1> {
// CHECK: %[[NUM_ELEMS_i32:.*]] = arith.index_cast %[[NUM_ELEMS]] : index to i32
// CHECK: %[[BOUNDS_INSERT:.*]] = llvm.insertelement %[[NUM_ELEMS_i32]]
// CHECK: %[[BOUNDS:.*]] = llvm.shufflevector %[[BOUNDS_INSERT]]
// CHECK: %[[RESULT:.*]] = arith.cmpi slt, %[[INDICES]], %[[BOUNDS]] : vector<4xi32>
// CHECK: %[[RESULT:.*]] = arith.cmpi sgt, %[[BOUNDS]], %[[INDICES]] : vector<4xi32>
// CHECK: return %[[RESULT]] : vector<4xi1>
// -----

@@ -14,30 +14,28 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
// CHECK-LABEL: func @transfer_read_write_1d
// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
// CHECK-SAME: %[[BASE:.*]]: index) -> vector<17xf32>
// CHECK: %[[C7:.*]] = arith.constant 7.0
// 1. Create pass-through vector.
// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<17xf32>
//
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
//
// 3. Let dim be the memref dimension, compute the in-bound index (dim - offset)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = arith.constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
//
// 3. Create bound vector to compute in-bound mask:
// 4. Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] :
// CMP32-SAME: index to i32
// CMP64-SAME: index to i64
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]] : vector<17x[[$IDX_TYPE]]>
// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
// CMP64-SAME: : vector<17xi64>
//
// 4. Create pass-through vector.
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<17xf32>
//
// 5. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -48,28 +46,23 @@ func.func @transfer_read_write_1d(%A : memref<?xf32>, %base: index) -> vector<17
// CHECK-SAME: -> vector<17xf32>
//
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex_b:.*]] = arith.constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<17x[[$IDX_TYPE]]>
//
// 3. Create bound vector to compute in-bound mask:
// 2. Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc_b:.*]] = arith.index_cast %[[BOUND_b]]
// CMP32-SAME: index to i32
// CHECK: %[[boundVecInsert_b:.*]] = llvm.insertelement %[[btrunc_b]]
// CHECK: %[[boundVect_b:.*]] = llvm.shufflevector %[[boundVecInsert_b]]
// CHECK: %[[mask_b:.*]] = arith.cmpi slt, %[[linearIndex_b]],
// CHECK-SAME: %[[boundVect_b]] : vector<17x[[$IDX_TYPE]]>
// CHECK: %[[mask_b:.*]] = arith.cmpi sgt, %[[boundVect_b]],
// CHECK-SAME: %[[linearIndex]] : vector<17x[[$IDX_TYPE]]>
//
// 4. Bitcast to vector form.
// 3. Bitcast to vector form.
// CHECK: %[[gep_b:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
//
// 5. Rewrite as a masked write.
// 4. Rewrite as a masked write.
// CHECK: llvm.intr.masked.store %[[loaded]], %[[gep_b]], %[[mask_b]]
// CHECK-SAME: {alignment = 4 : i32} :
// CHECK-SAME: vector<17xf32>, vector<17xi1> into !llvm.ptr
@@ -87,17 +80,18 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
// CHECK-LABEL: func @transfer_read_write_1d_scalable
// CHECK-SAME: %[[MEM:.*]]: memref<?xf32>,
// CHECK-SAME: %[[BASE:.*]]: index) -> vector<[17]xf32>
// CHECK: %[[C7:.*]] = arith.constant 7.0
// 1. Create pass-through vector.
// CHECK-DAG: %[[PASS_THROUGH:.*]] = arith.constant dense<7.000000e+00> : vector<[17]xf32>
//
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// 2. Let dim be the memref dimension, compute the in-bound index (dim - offset)
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE]] : index
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// 3. Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = llvm.intr.stepvector : vector<[17]x[[$IDX_TYPE]]>
//
// 3. Create bound vector to compute in-bound mask:
// 4. Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
@@ -105,9 +99,6 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
// CHECK-SAME: : vector<[17]x[[$IDX_TYPE]]>
//
// 4. Create pass-through vector.
// CHECK: %[[PASS_THROUGH:.*]] = arith.constant dense<7.{{.*}}> : vector<[17]xf32>
//
// 5. Bitcast to vector form.
// CHECK: %[[gep:.*]] = llvm.getelementptr %{{.*}} :
// CHECK-SAME: (!llvm.ptr, i64) -> !llvm.ptr, f32
@@ -118,8 +109,7 @@ func.func @transfer_read_write_1d_scalable(%A : memref<?xf32>, %base: index) ->
// CHECK-SAME: -> vector<[17]xf32>
//
// 1. Let dim be the memref dimension, compute the in-bound index (dim - offset)
// CHECK: %[[C0_b:.*]] = arith.constant 0 : index
// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0_b]] : memref<?xf32>
// CHECK: %[[DIM_b:.*]] = memref.dim %[[MEM]], %[[C0]] : memref<?xf32>
// CHECK: %[[BOUND_b:.*]] = arith.subi %[[DIM_b]], %[[BASE]] : index
//
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
@@ -197,23 +187,23 @@ func.func @transfer_read_2d_to_1d(%A : memref<?x?xf32>, %base0: index, %base1: i
}
// CHECK-LABEL: func @transfer_read_2d_to_1d
// CHECK-SAME: %[[BASE_0:[a-zA-Z0-9]*]]: index, %[[BASE_1:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
// CHECK: %[[c1:.*]] = arith.constant 1 : index
//
// Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK-DAG: %[[linearIndex:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
//
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c1]] : memref<?x?xf32>
//
// Compute the in-bound index (dim - offset)
// CHECK: %[[BOUND:.*]] = arith.subi %[[DIM]], %[[BASE_1]] : index
//
// Create a vector with linear indices [ 0 .. vector_length - 1 ].
// CHECK: %[[linearIndex:.*]] = arith.constant dense
// CHECK-SAME: <[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> :
// CHECK-SAME: vector<17x[[$IDX_TYPE]]>
//
// Create bound vector to compute in-bound mask:
// [ 0 .. vector_length - 1 ] < [ dim - offset .. dim - offset ]
// CHECK: %[[btrunc:.*]] = arith.index_cast %[[BOUND]] : index to [[$IDX_TYPE]]
// CHECK: %[[boundVecInsert:.*]] = llvm.insertelement %[[btrunc]]
// CHECK: %[[boundVect:.*]] = llvm.shufflevector %[[boundVecInsert]]
// CHECK: %[[mask:.*]] = arith.cmpi slt, %[[linearIndex]], %[[boundVect]]
// CHECK: %[[mask:.*]] = arith.cmpi sgt, %[[boundVect]], %[[linearIndex]]
func.func @transfer_read_2d_to_1d_scalable(%A : memref<?x?xf32>, %base0: index, %base1: index) -> vector<[17]xf32> {
%f7 = arith.constant 7.0: f32
@@ -255,12 +245,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace(%A : memref<?xf32, 3>, %bas
// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xf32>
//
// CHECK: %[[c0:.*]] = arith.constant 0 : index
//
// 1. Check address space for GEP is correct.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
//
// 2. Check address space of the memref is correct.
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
//
// 3. Check address space for GEP is correct.
@@ -280,12 +271,13 @@ func.func @transfer_read_write_1d_non_zero_addrspace_scalable(%A : memref<?xf32,
// CHECK-LABEL: func @transfer_read_write_1d_non_zero_addrspace_scalable
// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<[17]xf32>
//
// CHECK: %[[c0:.*]] = arith.constant 0 : index
//
// 1. Check address space for GEP is correct.
// CHECK: %[[gep:.*]] = llvm.getelementptr {{.*}} :
// CHECK-SAME: (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f32
//
// 2. Check address space of the memref is correct.
// CHECK: %[[c0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM:.*]] = memref.dim %{{.*}}, %[[c0]] : memref<?xf32, 3>
//
// 3. Check address space for GEP is correct.
@@ -330,10 +322,10 @@ func.func @transfer_read_1d_inbounds_scalable(%A : memref<?xf32>, %base: index)
// CHECK-LABEL: func @transfer_read_write_1d_mask
// CHECK: %[[mask1:.*]] = arith.constant dense<[false, false, true, false, true]>
// CHECK: %[[cmpi:.*]] = arith.cmpi slt
// CHECK: %[[cmpi:.*]] = arith.cmpi sgt
// CHECK: %[[mask2:.*]] = arith.andi %[[cmpi]], %[[mask1]]
// CHECK: %[[r:.*]] = llvm.intr.masked.load %{{.*}}, %[[mask2]]
// CHECK: %[[cmpi_1:.*]] = arith.cmpi slt
// CHECK: %[[cmpi_1:.*]] = arith.cmpi sgt
// CHECK: %[[mask3:.*]] = arith.andi %[[cmpi_1]], %[[mask1]]
// CHECK: llvm.intr.masked.store %[[r]], %{{.*}}, %[[mask3]]
// CHECK: return %[[r]]