llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
Alex Zinenko ba87f99168 [mlir] make vector to llvm conversion truly partial
Historically, the Vector to LLVM dialect conversion subsumed the Standard to
LLVM dialect conversion patterns. This was necessary because the conversion
infrastructure did not have sufficient support for reconciling type
conversions. This support is now available. Only keep the patterns related to
the Vector dialect in the Vector to LLVM conversion and require type cast
operations to be inserted if necessary. These casts will be removed by
following conversions if possible. Update integration tests to also run the
Standard to LLVM conversion.

There is a significant amount of test churn, which is due to (a) unnecessarily
strict tests in VectorToLLVM and (b) many patterns actually targeting Standard
dialect ops instead of LLVM dialect ops leading to tests actually exercising a
Vector->Standard->LLVM conversion. This churn is a good illustration of the
reason to make the conversion partial: now the tests only check the code in the
Vector to LLVM conversion and will not be randomly broken by changes in
Standard to LLVM conversion.

Arguably, it may be possible to extract Vector to Standard patterns into a
separate pass, but given the ongoing splitting of the Standard dialect, such a
pass will be short-lived and will require further refactoring.

Depends On D95626

Reviewed By: nicolasvasilache, aartbik

Differential Revision: https://reviews.llvm.org/D95685
2021-02-04 11:33:24 +01:00

122 lines
4.9 KiB
C++

//===- ConvertVectorToLLVMPass.cpp - Vector to LLVM conversion pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.h"
#include "mlir/Conversion/ArmNeonToLLVM/ArmNeonToLLVM.h"
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmNeonDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Pass that converts Vector dialect operations to the LLVM dialect.
///
/// The pass options (FP-reduction reassociation, index optimizations, and the
/// ArmNeon / ArmSVE / AVX512 extensions) are copied from a
/// `LowerVectorToLLVMOptions` instance at construction time.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  /// Builds the pass from `options`, copying every field into the
  /// corresponding pass option declared on the base class.
  ///
  /// Marked `explicit` so that an options struct is never implicitly
  /// converted into a pass object.
  explicit LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
    this->enableIndexOptimizations = options.enableIndexOptimizations;
    this->enableArmNeon = options.enableArmNeon;
    this->enableArmSVE = options.enableArmSVE;
    this->enableAVX512 = options.enableAVX512;
  }

  // Override explicitly to allow conditional dialect dependence.
  /// Registers the LLVM dialect unconditionally, and each architecture-specific
  /// LLVM dialect only when the matching option is enabled.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
    if (enableArmNeon)
      registry.insert<LLVM::LLVMArmNeonDialect>();
    if (enableArmSVE)
      registry.insert<LLVM::LLVMArmSVEDialect>();
    if (enableAVX512)
      registry.insert<LLVM::LLVMAVX512Dialect>();
  }

  /// Runs the lowering on the operation this pass is scheduled on; defined
  /// out of line below.
  void runOnOperation() override;
};
} // namespace
/// Lowers Vector dialect operations nested under the current operation to the
/// LLVM dialect.
///
/// The lowering runs in two stages:
///   1. a greedy rewrite applying vector-to-vector canonicalization, slices
///      lowering and contraction lowering patterns;
///   2. a partial dialect conversion using the Vector-to-LLVM patterns, plus
///      the ArmNeon / ArmSVE / AVX512 patterns when the matching option is set.
///
/// The pass is marked as failed if the partial conversion fails.
void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    // The outcome of the greedy rewrite is intentionally ignored; only the
    // dialect conversion below decides whether the pass fails.
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  // Each pattern set is populated exactly once; registering the matrix
  // patterns a second time would only add duplicates to the list.
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(
      converter, patterns, reassociateFPReductions, enableIndexOptimizations);

  // Cast ops and the whole Standard dialect are declared legal, so this
  // conversion leaves them in place rather than lowering them.
  LLVMConversionTarget target(getContext());
  target.addLegalOp<LLVM::DialectCastOp>();
  target.addLegalDialect<StandardOpsDialect>();
  target.addLegalOp<UnrealizedConversionCastOp>();

  // Architecture specific augmentations.
  if (enableArmNeon) {
    target.addLegalDialect<LLVM::LLVMArmNeonDialect>();
    target.addIllegalDialect<arm_neon::ArmNeonDialect>();
    populateArmNeonToLLVMConversionPatterns(converter, patterns);
  }
  if (enableArmSVE) {
    target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
    target.addIllegalDialect<arm_sve::ArmSVEDialect>();
    // Returns true if any type in `types` is an ArmSVE scalable vector type.
    auto hasScalableVectorType = [](TypeRange types) {
      for (Type type : types)
        if (type.isa<arm_sve::ScalableVectorType>())
          return true;
      return false;
    };
    // Remove any ArmSVE-specific types from function signatures and results.
    populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
    // Functions, calls and returns are legal only when none of their
    // argument/operand or result types is a scalable vector type.
    target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
      return !hasScalableVectorType(op.getType().getInputs()) &&
             !hasScalableVectorType(op.getType().getResults());
    });
    target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
        [hasScalableVectorType](Operation *op) {
          return !hasScalableVectorType(op->getOperandTypes()) &&
                 !hasScalableVectorType(op->getResultTypes());
        });
    populateArmSVEToLLVMConversionPatterns(converter, patterns);
  }
  if (enableAVX512) {
    target.addLegalDialect<LLVM::LLVMAVX512Dialect>();
    target.addIllegalDialect<avx512::AVX512Dialect>();
    populateAVX512ToLLVMConversionPatterns(converter, patterns);
  }

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
/// Factory for the Vector-to-LLVM conversion pass: constructs a
/// `LowerVectorToLLVMPass` from `options` and returns it as a module-level
/// operation pass.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerVectorToLLVMPass>(options);
  return pass;
}