Historically, the Vector to LLVM dialect conversion subsumed the Standard to LLVM dialect conversion patterns. This was necessary because the conversion infrastructure did not have sufficient support for reconciling type conversions. This support is now available. Only keep the patterns related to the Vector dialect in the Vector to LLVM conversion, and require type cast operations to be inserted where necessary. These casts will be removed by subsequent conversions when possible. Update integration tests to also run the Standard to LLVM conversion. There is a significant amount of test churn, which is due to (a) unnecessarily strict tests in VectorToLLVM and (b) many patterns actually targeting Standard dialect ops instead of LLVM dialect ops, which led to tests actually exercising a Vector->Standard->LLVM conversion. This churn is a good illustration of the reason to make the conversion partial: now the tests only check the code in the Vector to LLVM conversion and will not be randomly broken by changes in the Standard to LLVM conversion. Arguably, it may be possible to extract the Vector to Standard patterns into a separate pass, but given the ongoing splitting of the Standard dialect, such a pass would be short-lived and would require further refactoring.
Depends On: D95626
Reviewed By: nicolasvasilache, aartbik
Differential Revision: https://reviews.llvm.org/D95685
122 lines
4.9 KiB
C++
122 lines
4.9 KiB
C++
//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
|
|
#include "../PassDetail.h"
|
|
|
|
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
|
|
#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
|
|
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
|
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
|
|
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
|
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
|
|
namespace {
|
|
struct LowerVectorToLLVMPass
|
|
: public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
|
|
LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
|
|
this->reassociateFPReductions = options.reassociateFPReductions;
|
|
this->enableIndexOptimizations = options.enableIndexOptimizations;
|
|
this->enableArmNeon = options.enableArmNeon;
|
|
this->enableArmSVE = options.enableArmSVE;
|
|
this->enableAVX512 = options.enableAVX512;
|
|
}
|
|
// Override explicitly to allow conditional dialect dependence.
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<LLVM::LLVMDialect>();
|
|
if (enableArmNeon)
|
|
registry.insert<LLVM::LLVMArmNeonDialect>();
|
|
if (enableArmSVE)
|
|
registry.insert<LLVM::LLVMArmSVEDialect>();
|
|
if (enableAVX512)
|
|
registry.insert<LLVM::LLVMAVX512Dialect>();
|
|
}
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void LowerVectorToLLVMPass::runOnOperation() {
|
|
// Perform progressive lowering of operations on slices and
|
|
// all contraction operations. Also applies folding and DCE.
|
|
{
|
|
OwningRewritePatternList patterns;
|
|
populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
|
|
populateVectorSlicesLoweringPatterns(patterns, &getContext());
|
|
populateVectorContractLoweringPatterns(patterns, &getContext());
|
|
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
// Convert to the LLVM IR dialect.
|
|
LLVMTypeConverter converter(&getContext());
|
|
OwningRewritePatternList patterns;
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
populateVectorToLLVMConversionPatterns(
|
|
converter, patterns, reassociateFPReductions, enableIndexOptimizations);
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
|
|
// Architecture specific augmentations.
|
|
LLVMConversionTarget target(getContext());
|
|
target.addLegalOp<LLVM::DialectCastOp>();
|
|
target.addLegalDialect<StandardOpsDialect>();
|
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
|
if (enableArmNeon) {
|
|
target.addLegalDialect<LLVM::LLVMArmNeonDialect>();
|
|
target.addIllegalDialect<arm_neon::ArmNeonDialect>();
|
|
populateArmNeonToLLVMConversionPatterns(converter, patterns);
|
|
}
|
|
if (enableArmSVE) {
|
|
target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
|
|
target.addIllegalDialect<arm_sve::ArmSVEDialect>();
|
|
auto hasScalableVectorType = [](TypeRange types) {
|
|
for (Type type : types)
|
|
if (type.isa<arm_sve::ScalableVectorType>())
|
|
return true;
|
|
return false;
|
|
};
|
|
// Remove any ArmSVE-specific types from function signatures and results.
|
|
populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
|
|
target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
|
|
return !hasScalableVectorType(op.getType().getInputs()) &&
|
|
!hasScalableVectorType(op.getType().getResults());
|
|
});
|
|
target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
|
|
[hasScalableVectorType](Operation *op) {
|
|
return !hasScalableVectorType(op->getOperandTypes()) &&
|
|
!hasScalableVectorType(op->getResultTypes());
|
|
});
|
|
populateArmSVEToLLVMConversionPatterns(converter, patterns);
|
|
}
|
|
if (enableAVX512) {
|
|
target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
|
|
target.addIllegalDialect<avx512::AVX512Dialect>();
|
|
populateAVX512ToLLVMConversionPatterns(converter, patterns);
|
|
}
|
|
|
|
if (failed(
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
|
|
/// Creates a Vector-to-LLVM lowering pass configured from `options`.
/// Ownership of the new pass is transferred to the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerVectorToLLVMPass>(options);
  return pass;
}
|