[mlir][Vector] Lower vector.to_elements to LLVM (#145766)

Only elements with at least one use are lowered to `llvm.extractelement`
op.
This commit is contained in:
Diego Caballero 2025-06-26 10:36:08 -07:00 committed by GitHub
parent abf8e25ac7
commit 7842e9eada
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 73 additions and 2 deletions

View File

@ -1985,6 +1985,38 @@ struct VectorFromElementsLowering
} }
}; };
/// Conversion pattern for a `vector.to_elements`.
struct VectorToElementsLowering
: public ConvertOpToLLVMPattern<vector::ToElementsOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = toElementsOp.getLoc();
auto idxType = typeConverter->convertType(rewriter.getIndexType());
Value source = adaptor.getSource();
SmallVector<Value> results(toElementsOp->getNumResults());
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
// Create an extractelement operation only for results that are not dead.
if (element.use_empty())
continue;
auto constIdx = rewriter.create<LLVM::ConstantOp>(
loc, idxType, rewriter.getIntegerAttr(idxType, idx));
auto llvmType = typeConverter->convertType(element.getType());
Value result = rewriter.create<LLVM::ExtractElementOp>(loc, llvmType,
source, constIdx);
results[idx] = result;
}
rewriter.replaceOp(toElementsOp, results);
return success();
}
};
/// Conversion pattern for vector.step. /// Conversion pattern for vector.step.
struct VectorScalableStepOpLowering struct VectorScalableStepOpLowering
: public ConvertOpToLLVMPattern<vector::StepOp> { : public ConvertOpToLLVMPattern<vector::StepOp> {
@ -2035,7 +2067,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering, MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering, VectorDeinterleaveOpLowering, VectorFromElementsLowering,
VectorScalableStepOpLowering>(converter); VectorToElementsLowering, VectorScalableStepOpLowering>(
converter);
} }
void mlir::populateVectorToLLVMMatrixConversionPatterns( void mlir::populateVectorToLLVMMatrixConversionPatterns(

View File

@ -2421,6 +2421,44 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// ----- // -----
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @to_elements_no_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}
// -----
// CHECK-LABEL: func.func @to_elements_dead_elements
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
%0:4 = vector.to_elements %a : vector<4xf32>
return %0#1, %0#3 : f32, f32
}
// -----
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// vector.step // vector.step
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//