This patch adds two LLVM intrinsics to the ArmSME dialect: * llvm.aarch64.sme.za.enable * llvm.aarch64.sme.za.disable for enabling the ZA storage array [1], as well as patterns for inserting them during legalization to LLVM at the start and end of functions if the function has the 'arm_za' attribute (D152695). In the future ZA should probably be automatically enabled/disabled when lowering from vector to SME, but this should be sufficient for now at least until we have patterns lowering to SME instructions that use ZA. N.B. The backend function attribute 'aarch64_pstate_za_new' can be used to manage ZA state (as was originally tried in D152694), but it emits calls to the following SME support routines [2] for the lazy-save mechanism [3]: * __arm_tpidr2_restore * __arm_tpidr2_save These will soon be added to compiler-rt but there's currently no public implementation, and using this attribute would introduce an MLIR dependency on compiler-rt. Furthermore, this mechanism is for routines with ZA enabled calling other routines with it also enabled. We can choose not to enable ZA in the compiler when this is the case. Depends on D152695 [1] https://developer.arm.com/documentation/ddi0616/aa [2] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#sme-support-routines [3] https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#the-za-lazy-saving-scheme Reviewed By: awarzynski, dcaballe Differential Revision: https://reviews.llvm.org/D153050
126 lines
4.9 KiB
C++
126 lines
4.9 KiB
C++
//===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Dialect/AMX/AMXDialect.h"
|
|
#include "mlir/Dialect/AMX/Transforms.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
|
|
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
|
|
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
|
|
#include "mlir/Dialect/ArmSVE/Transforms.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
|
#include "mlir/Dialect/X86Vector/Transforms.h"
|
|
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTVECTORTOLLVMPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
|
|
namespace {
|
|
struct LowerVectorToLLVMPass
|
|
: public impl::ConvertVectorToLLVMPassBase<LowerVectorToLLVMPass> {
|
|
|
|
using Base::Base;
|
|
|
|
// Override explicitly to allow conditional dialect dependence.
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<LLVM::LLVMDialect>();
|
|
registry.insert<arith::ArithDialect>();
|
|
registry.insert<memref::MemRefDialect>();
|
|
if (armNeon)
|
|
registry.insert<arm_neon::ArmNeonDialect>();
|
|
if (armSVE)
|
|
registry.insert<arm_sve::ArmSVEDialect>();
|
|
if (armSME)
|
|
registry.insert<arm_sme::ArmSMEDialect>();
|
|
if (amx)
|
|
registry.insert<amx::AMXDialect>();
|
|
if (x86Vector)
|
|
registry.insert<x86vector::X86VectorDialect>();
|
|
}
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void LowerVectorToLLVMPass::runOnOperation() {
|
|
// Perform progressive lowering of operations on slices and
|
|
// all contraction operations. Also applies folding and DCE.
|
|
{
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
populateVectorBroadcastLoweringPatterns(patterns);
|
|
populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
|
|
populateVectorMaskOpLoweringPatterns(patterns);
|
|
populateVectorShapeCastLoweringPatterns(patterns);
|
|
populateVectorTransposeLoweringPatterns(patterns,
|
|
VectorTransformsOptions());
|
|
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
|
|
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
// Convert to the LLVM IR dialect.
|
|
LowerToLLVMOptions options(&getContext());
|
|
options.useOpaquePointers = useOpaquePointers;
|
|
LLVMTypeConverter converter(&getContext(), options);
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorMaskMaterializationPatterns(patterns, force32BitVectorIndices);
|
|
populateVectorTransferLoweringPatterns(patterns);
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
populateVectorToLLVMConversionPatterns(
|
|
converter, patterns, reassociateFPReductions, force32BitVectorIndices);
|
|
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
|
|
|
// Architecture specific augmentations.
|
|
LLVMConversionTarget target(getContext());
|
|
target.addLegalDialect<arith::ArithDialect>();
|
|
target.addLegalDialect<memref::MemRefDialect>();
|
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
|
if (armNeon) {
|
|
// TODO: we may or may not want to include in-dialect lowering to
|
|
// LLVM-compatible operations here. So far, all operations in the dialect
|
|
// can be translated to LLVM IR so there is no conversion necessary.
|
|
target.addLegalDialect<arm_neon::ArmNeonDialect>();
|
|
}
|
|
if (armSVE) {
|
|
configureArmSVELegalizeForExportTarget(target);
|
|
populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
|
|
}
|
|
if (armSME) {
|
|
configureArmSMELegalizeForExportTarget(target);
|
|
populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
|
|
}
|
|
if (amx) {
|
|
configureAMXLegalizeForExportTarget(target);
|
|
populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
|
|
}
|
|
if (x86Vector) {
|
|
configureX86VectorLegalizeForExportTarget(target);
|
|
populateX86VectorLegalizeForLLVMExportPatterns(converter, patterns);
|
|
}
|
|
|
|
if (failed(
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|