llvm-project/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
Alex Zinenko a776942ba1 [mlir] squash LLVM_AVX512 dialect into AVX512
The dialect separation was introduced to demarcate ops operating in different
type systems. This is no longer the case after the LLVM dialect has migrated to
using built-in vector types, so the original reason for separation is no longer
valid. Squash the two dialects into one.

The code size decrease is not very large: the ops originally in LLVM_AVX512 are
preserved because they match LLVM IR intrinsics specialized for vector element
bitwidth. However, it is still conceptually beneficial to have only one
dialect. I originally considered using TableGen multiclasses to define both
the type-polymorphic op and its two intrinsic-related instantiations, but
decided against it given both the complexity of the required TableGen input and
its dissimilarity with the rest of ODS-defined ops, both potentially resulting
in very poor maintainability.

Depends On D98327

Reviewed By: nicolasvasilache, springerm

Differential Revision: https://reviews.llvm.org/D98328
2021-03-10 13:07:26 +01:00

119 lines
4.8 KiB
C++

//===- ConvertVectorToLLVMPass.cpp - Vector to LLVM dialect pass ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h"
#include "mlir/Dialect/AVX512/Transforms.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
namespace {
/// Pass that converts Vector dialect operations to the LLVM dialect and,
/// depending on its options, also handles the ArmNeon, ArmSVE and AVX512
/// architecture-specific dialects.
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  /// Copies every field of `options` into the corresponding pass option
  /// inherited from `ConvertVectorToLLVMBase`. Marked `explicit` so an options
  /// struct is never silently converted into a pass object.
  explicit LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
    this->enableIndexOptimizations = options.enableIndexOptimizations;
    this->enableArmNeon = options.enableArmNeon;
    this->enableArmSVE = options.enableArmSVE;
    this->enableAVX512 = options.enableAVX512;
  }

  /// Registers the dialects whose ops this pass may create. The LLVM dialect
  /// is always needed. The architecture-specific dialects are only registered
  /// when the matching option is enabled, which is why this is overridden
  /// explicitly instead of being declared statically.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<LLVM::LLVMDialect>();
    if (enableArmNeon)
      registry.insert<arm_neon::ArmNeonDialect>();
    // ArmSVE ops are lowered to the LLVM-side ArmSVE dialect, so that is the
    // one that must be available.
    if (enableArmSVE)
      registry.insert<LLVM::LLVMArmSVEDialect>();
    if (enableAVX512)
      registry.insert<avx512::AVX512Dialect>();
  }

  void runOnOperation() override;
};
} // namespace
/// Lowers the current operation's Vector dialect ops (plus any enabled
/// architecture-specific dialects) to the LLVM dialect in two phases:
///   1. a greedy rewrite that canonicalizes vector ops and progressively
///      lowers slice and contraction ops, also applying folding and DCE;
///   2. a partial dialect conversion towards the LLVM dialect.
/// Calls signalPassFailure() if the partial conversion fails.
void LowerVectorToLLVMPass::runOnOperation() {
  // Phase 1: progressive lowering of operations on slices and all contraction
  // operations. Also applies folding and DCE. The result of the greedy driver
  // is deliberately discarded: non-convergence is not treated as an error.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Phase 2: convert to the LLVM IR dialect. The matrix conversion patterns
  // are registered exactly once; registering the same set a second time only
  // makes the conversion driver retry identical patterns.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(
      converter, patterns, reassociateFPReductions, enableIndexOptimizations);

  // Cast ops and the Standard dialect are legal, so this pass leaves them in
  // place (the ArmSVE branch below adds extra legality rules for some ops).
  LLVMConversionTarget target(getContext());
  target.addLegalOp<LLVM::DialectCastOp>();
  target.addLegalDialect<StandardOpsDialect>();
  target.addLegalOp<UnrealizedConversionCastOp>();

  // Architecture specific augmentations.
  if (enableArmNeon) {
    // TODO: we may or may not want to include in-dialect lowering to
    // LLVM-compatible operations here. So far, all operations in the dialect
    // can be translated to LLVM IR so there is no conversion necessary.
    target.addLegalDialect<arm_neon::ArmNeonDialect>();
  }
  if (enableArmSVE) {
    target.addLegalDialect<LLVM::LLVMArmSVEDialect>();
    target.addIllegalDialect<arm_sve::ArmSVEDialect>();
    // True if any of `types` is an ArmSVE scalable vector type.
    auto hasScalableVectorType = [](TypeRange types) {
      for (Type type : types)
        if (type.isa<arm_sve::ScalableVectorType>())
          return true;
      return false;
    };
    // Remove any ArmSVE-specific types from function signatures and results:
    // functions, calls and returns only become legal once no scalable vector
    // type remains in their signature, operands or results.
    populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
    target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
      return !hasScalableVectorType(op.getType().getInputs()) &&
             !hasScalableVectorType(op.getType().getResults());
    });
    target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
        [hasScalableVectorType](Operation *op) {
          return !hasScalableVectorType(op->getOperandTypes()) &&
                 !hasScalableVectorType(op->getResultTypes());
        });
    populateArmSVEToLLVMConversionPatterns(converter, patterns);
  }
  if (enableAVX512) {
    configureAVX512LegalizeForExportTarget(target);
    populateAVX512LegalizeForLLVMExportPatterns(converter, patterns);
  }

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
/// Factory for the Vector-to-LLVM conversion pass. The returned pass is
/// configured from `options` and is owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<LowerVectorToLLVMPass>(options);
  return pass;
}