llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
thomasraoux be8e2801a4 [mlir][vector][NFC] split TransposeOp lowering out of contractLowering
Move TransposeOp lowering into its own populate function, as in some cases
it is better to keep the op during ContractOp lowering so that it can be
canonicalized rather than emitting scalar insert/extract.

Differential Revision: https://reviews.llvm.org/D101647
2021-05-03 10:23:45 -07:00

114 lines
4.5 KiB
C++

//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Pass that lowers vector dialect operations to the LLVM dialect, with
/// optional support for the ArmNeon, ArmSVE, AMX and X86Vector dialects
/// selected through the pass options.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  /// Builds the pass by copying each field of `options` into the pass option
  /// of the same name.
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    reassociateFPReductions = options.reassociateFPReductions;
    enableIndexOptimizations = options.enableIndexOptimizations;
    enableArmNeon = options.enableArmNeon;
    enableArmSVE = options.enableArmSVE;
    enableAMX = options.enableAMX;
    enableX86Vector = options.enableX86Vector;
  }

  /// Registers the dialects this pass depends on. Overridden explicitly so
  /// that the architecture-specific dialects are only registered when the
  /// corresponding option is enabled.
  void getDependentDialects(DialectRegistry &registry) const override {
    // Always required, independently of the selected options.
    registry.insert<LLVM::LLVMDialect, memref::MemRefDialect>();

    // Conditional, architecture-specific dependencies.
    if (enableArmNeon) {
      registry.insert<arm_neon::ArmNeonDialect>();
    }
    if (enableArmSVE) {
      registry.insert<arm_sve::ArmSVEDialect>();
    }
    if (enableAMX) {
      registry.insert<amx::AMXDialect>();
    }
    if (enableX86Vector) {
      registry.insert<x86vector::X86VectorDialect>();
    }
  }

  /// Entry point of the pass; defined out of line below.
  void runOnOperation() override;
};
} // namespace
/// Lowers the vector operations nested under the current operation to the
/// LLVM dialect.
///
/// The lowering runs in two stages:
///   1. A greedy rewrite that applies vector-to-vector canonicalization
///      together with the slices, contraction and transpose lowering
///      patterns.
///   2. A partial dialect conversion to the LLVM dialect, optionally extended
///      with the ArmSVE, AMX and X86Vector export legalizations (and with the
///      ArmNeon dialect marked legal) depending on the pass options.
///
/// The pass is marked as failed if the partial conversion fails.
void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    RewritePatternSet patterns(&getContext());
    populateVectorToVectorCanonicalizationPatterns(patterns);
    populateVectorSlicesLoweringPatterns(patterns);
    populateVectorContractLoweringPatterns(patterns);
    populateVectorTransposeLoweringPatterns(patterns);
    // The result of the greedy driver is explicitly discarded.
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  populateVectorMaskMaterializationPatterns(patterns, enableIndexOptimizations);
  // Register the matrix conversion patterns exactly once; adding the same
  // patterns a second time is redundant.
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns,
                                         reassociateFPReductions);

  // Set up the conversion target: on top of what LLVMConversionTarget
  // provides, the ops and dialects listed here are declared legal.
  LLVMConversionTarget target(getContext());
  target.addLegalOp<LLVM::DialectCastOp>();
  target.addLegalDialect<memref::MemRefDialect>();
  target.addLegalDialect<StandardOpsDialect>();
  target.addLegalOp<UnrealizedConversionCastOp>();

  // Architecture specific augmentations.
  if (enableArmNeon) {
    // TODO: we may or may not want to include in-dialect lowering to
    // LLVM-compatible operations here. So far, all operations in the dialect
    // can be translated to LLVM IR so there is no conversion necessary.
    target.addLegalDialect<arm_neon::ArmNeonDialect>();
  }
  if (enableArmSVE) {
    configureArmSVELegalizeForExportTarget(target);
    populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
  }
  if (enableAMX) {
    configureAMXLegalizeForExportTarget(target);
    populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
  }
  if (enableX86Vector) {
    configureX86VectorLegalizeForExportTarget(target);
    populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
  }

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
/// Creates a module-level pass that converts vector dialect operations to the
/// LLVM dialect, configured from `options`.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerVectorToLLVMPass>(options);
  return pass;
}