[mlir][Vector] Lower vector.to_elements
to LLVM (#145766)
Only elements with at least one use are lowered to `llvm.extractelement` op.
This commit is contained in:
parent
abf8e25ac7
commit
7842e9eada
@ -1985,6 +1985,38 @@ struct VectorFromElementsLowering
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Conversion pattern for a `vector.to_elements`.
|
||||||
|
struct VectorToElementsLowering
|
||||||
|
: public ConvertOpToLLVMPattern<vector::ToElementsOp> {
|
||||||
|
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
Location loc = toElementsOp.getLoc();
|
||||||
|
auto idxType = typeConverter->convertType(rewriter.getIndexType());
|
||||||
|
Value source = adaptor.getSource();
|
||||||
|
|
||||||
|
SmallVector<Value> results(toElementsOp->getNumResults());
|
||||||
|
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
|
||||||
|
// Create an extractelement operation only for results that are not dead.
|
||||||
|
if (element.use_empty())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
auto constIdx = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
loc, idxType, rewriter.getIntegerAttr(idxType, idx));
|
||||||
|
auto llvmType = typeConverter->convertType(element.getType());
|
||||||
|
|
||||||
|
Value result = rewriter.create<LLVM::ExtractElementOp>(loc, llvmType,
|
||||||
|
source, constIdx);
|
||||||
|
results[idx] = result;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(toElementsOp, results);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Conversion pattern for vector.step.
|
/// Conversion pattern for vector.step.
|
||||||
struct VectorScalableStepOpLowering
|
struct VectorScalableStepOpLowering
|
||||||
: public ConvertOpToLLVMPattern<vector::StepOp> {
|
: public ConvertOpToLLVMPattern<vector::StepOp> {
|
||||||
@ -2035,7 +2067,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
|||||||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
|
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
|
||||||
MaskedReductionOpConversion, VectorInterleaveOpLowering,
|
MaskedReductionOpConversion, VectorInterleaveOpLowering,
|
||||||
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
|
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
|
||||||
VectorScalableStepOpLowering>(converter);
|
VectorToElementsLowering, VectorScalableStepOpLowering>(
|
||||||
|
converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
||||||
|
@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
|
|||||||
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||||
// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
|
// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
|
||||||
// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
|
// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
|
||||||
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
|
||||||
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
|
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
|
||||||
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
|
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
|
||||||
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
|
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
|
||||||
@ -2421,6 +2421,44 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// vector.to_elements
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @to_elements_no_dead_elements
|
||||||
|
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
|
||||||
|
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
|
||||||
|
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
|
||||||
|
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||||
|
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
|
||||||
|
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
|
||||||
|
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
|
||||||
|
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
|
||||||
|
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
|
||||||
|
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
|
||||||
|
func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
|
||||||
|
%0:4 = vector.to_elements %a : vector<4xf32>
|
||||||
|
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @to_elements_dead_elements
|
||||||
|
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
|
||||||
|
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
|
||||||
|
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
|
||||||
|
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
|
||||||
|
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
|
||||||
|
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
|
||||||
|
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
|
||||||
|
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
|
||||||
|
func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
|
||||||
|
%0:4 = vector.to_elements %a : vector<4xf32>
|
||||||
|
return %0#1, %0#3 : f32, f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// vector.step
|
// vector.step
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
Loading…
x
Reference in New Issue
Block a user