llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
Markus Böck cd4ca2d7f9 [mlir] Port Conversion Passes to LLVM to use TableGen generated constructors and options
See https://github.com/llvm/llvm-project/issues/57475 for more context.

Using auto-generated constructors and options has significant advantages:
* It forces a uniform style and expectation for consuming a pass
* It makes it very easy to add, remove or change a pass's options by simply making the changes in TableGen
* It's less code

This patch in particular ports all the conversion passes which lower to LLVM to use the auto generated constructors and options. For the most part, care was taken so that auto generated constructor functions have the same name as they previously did. Only the following slight breaking changes (which I consider worth the churn) have been made:
* `mlir::cf::createConvertControlFlowToLLVMPass` has been moved to the `mlir` namespace. This is consistent with basically all conversion passes
* `createGpuToLLVMConversionPass` now takes a proper options struct array for its pass options. The pass options are now also autogenerated.
* `LowerVectorToLLVMOptions` has been replaced by the autogenerated `ConvertVectorToLLVMPassOptions` which is automatically kept up to date by TableGen
* I had to move one function in the GPU to LLVM lowering as it is used as default value for an option.
* All passes that previously returned `unique_ptr<OperationPass<...>>` now simply return `unique_ptr<Pass>`

Differential Revision: https://reviews.llvm.org/D143773
2023-02-10 20:47:18 +01:00

114 lines
4.4 KiB
C++

//===- ConvertVectorToLLVMPass.cpp - Vector to LLVM conversion pass -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Pass implementation that lowers vector operations to the LLVM dialect.
///
/// The flags read below (`armNeon`, `armSVE`, `amx`, `x86Vector`) are not
/// declared in this struct; presumably they are pass options provided by the
/// TableGen-generated base class.
struct LowerVectorToLLVMPass
    : public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
  using Base::Base;

  void runOnOperation() override;

  /// Registers the dialects this pass depends on. Overridden by hand so that
  /// the architecture-specific dialects are only registered when the
  /// corresponding flag is set.
  void getDependentDialects(DialectRegistry &registry) const override {
    // Dialects that are registered regardless of the selected flags.
    registry.insert<LLVM::LLVMDialect, arith::ArithDialect,
                    memref::MemRefDialect>();

    // Dialects that are registered only when their flag is enabled.
    if (armNeon) {
      registry.insert<arm_neon::ArmNeonDialect>();
    }
    if (armSVE) {
      registry.insert<arm_sve::ArmSVEDialect>();
    }
    if (amx) {
      registry.insert<amx::AMXDialect>();
    }
    if (x86Vector) {
      registry.insert<x86vector::X86VectorDialect>();
    }
  }
};
} // namespace
void LowerVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
RewritePatternSet patterns(&getContext());
populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
populateVectorTransferLoweringPatterns(patterns);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, force32BitVectorIndices);
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
// Architecture specific augmentations.
LLVMConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
if (armNeon) {
// TODO: we may or may not want to include in-dialect lowering to
// LLVM-compatible operations here. So far, all operations in the dialect
// can be translated to LLVM IR so there is no conversion necessary.
target.addLegalDialect<arm_neon::ArmNeonDialect>();
}
if (armSVE) {
configureArmSVELegalizeForExportTarget(target);
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
}
if (amx) {
configureAMXLegalizeForExportTarget(target);
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
}
if (x86Vector) {
configureX86VectorLegalizeForExportTarget(target);
populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
}
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}