llvm-project/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
Matthias Springer 206fad0e21
[mlir][NFC] Mark type converter in populate... functions as const (#111250)
This commit marks the type converter in `populate...` functions as
`const`. This is useful for debugging.

Patterns already take a `const` type converter. However, some
`populate...` functions do not only add new patterns, but also add
additional type conversion rules. That makes it difficult to find the
place where a type conversion was added in the code base. With this
change, all `populate...` functions that only populate patterns now have
a `const` type converter. Programmers can then conclude from the
function signature that these functions do not register any new type
conversion rules.

Also some minor cleanups around the 1:N dialect conversion
infrastructure, which did not always pass the type converter as a
`const` object internally.
2024-10-05 21:32:40 +02:00

276 lines
12 KiB
C++

//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::arm_sve;
/// Conversion pattern that leaves an `OpTy` operation in place and only swaps
/// its operands for the (type-converted) operands provided by the adaptor.
///
/// The pattern reports a match failure when the adaptor operand types are
/// already identical to the operand types currently on the op, so it only
/// fires when there is something to update.
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto convertedOperands = adaptor.getOperands();

    // Nothing to forward if the op already carries operands of these types.
    const bool typesAlreadyMatch =
        convertedOperands.getTypes() == op->getOperands().getTypes();
    if (typesAlreadyMatch)
      return rewriter.notifyMatchFailure(op, "operand types already match");

    // Mutate the existing op (rather than creating a new one) and let the
    // rewriter know about the in-place modification.
    rewriter.modifyOpInPlace(op, [&] { op->setOperands(convertedOperands); });
    return success();
  }
};
// One-to-one lowerings: each alias below pairs an ArmSVE op with the `*IntrOp`
// it is rewritten to, using the generic `OneToOneConvertToLLVMPattern` helper.

// `sdot`, `smmla`, `udot` and `ummla`.
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
// `ScalableMasked*` arithmetic ops: add, sub, mul and div variants.
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
ScalableMaskedDivFIntrOp>;
namespace {
/// Lowers an svbool conversion `Op` on a (possibly n-D) vector by unrolling it
/// into one `IntrOp` per trailing-dimension slice, since the conversion
/// intrinsic only supports 1-D vector types.
///
/// Every slice is extracted from the source, converted with `IntrOp`, and
/// inserted at the same position into a result vector that starts out as an
/// all-zero constant of the result type.
///
/// Example:
/// ```
/// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
/// ```
/// becomes:
/// ```
/// %cst = arith.constant dense<false> : vector<2x[16]xi1>
/// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
/// %2 = "arm_sve.intr.convert.to.svbool"(%1)
///                : (vector<[4]xi1>) -> vector<[16]xi1>
/// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
/// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
/// %5 = "arm_sve.intr.convert.to.svbool"(%4)
///                : (vector<[4]xi1>) -> vector<[16]xi1>
/// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
/// ```
template <typename Op, typename IntrOp>
struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
  using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Op convertOp, typename Op::Adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = convertOp.getLoc();
    auto input = convertOp.getSource();
    VectorType inputType = input.getType();
    VectorType outputType = convertOp.getResult().getType();
    ArrayRef<int64_t> inputShape = inputType.getShape();
    // Size of the trailing dimension every converted slice must end up with.
    const int64_t convertedTrailingDim = outputType.getShape().back();

    // The accumulator starts as a zero-filled vector of the result type; each
    // converted slice is inserted into it below.
    Value accumulated = rewriter.create<arith::ConstantOp>(
        loc, outputType, rewriter.getZeroAttr(outputType));

    // Step through the input one trailing-dimension slice at a time: the tile
    // is 1 in every leading dimension and spans the whole trailing dimension.
    SmallVector<int64_t> sliceShape(inputType.getRank(), 1);
    sliceShape.back() = inputShape.back();

    // Visit every scalable mask/predicate slice of the input vector.
    for (SmallVector<int64_t> offsets :
         StaticTileOffsetRange(inputShape, sliceShape)) {
      // The trailing offset addresses elements inside the slice; dropping it
      // leaves the position of the slice itself.
      auto slicePosition = ArrayRef(offsets).drop_back();

      auto slice =
          rewriter.create<vector::ExtractOp>(loc, input, slicePosition);
      VectorType convertedSliceType =
          VectorType::Builder(llvm::cast<VectorType>(slice.getType()))
              .setDim(0, convertedTrailingDim);
      auto convertedSlice =
          rewriter.create<IntrOp>(loc, TypeRange{convertedSliceType}, slice);

      accumulated = rewriter.create<vector::InsertOp>(
          loc, convertedSlice, accumulated, slicePosition);
    }

    rewriter.replaceOp(convertOp, accumulated);
    return success();
  }
};
// Instantiations of SvboolConversionOpLowering for the two conversion
// directions (to svbool and from svbool).
using ConvertToSvboolOpLowering =
SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
using ConvertFromSvboolOpLowering =
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
// One-to-one lowerings of the zip ops to their `*IntrOp` counterparts.
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
/// Lowers `arm_sve.psel` to LLVM intrinsics. The mapping is nearly one-to-one;
/// the only extra work is that the first predicate (P1) is converted to svbool
/// before the intrinsic and the intrinsic result is converted back from svbool
/// to P1's type afterwards. The index operand is cast to i32.
struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = pselOp.getLoc();
    // svbool is modelled here as vector<[16]xi1>.
    auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
    Value p1 = adaptor.getP1();

    // P1 -> svbool.
    auto p1AsSvbool =
        rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType, p1);
    // The intrinsic is given the index as an i32.
    auto index = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getI32Type(), pselOp.getIndex());
    // The intrinsic itself, operating on (and producing) svbool.
    auto selected = rewriter.create<PselIntrOp>(loc, svboolType, p1AsSvbool,
                                                pselOp.getP2(), index);
    // svbool -> type of P1, which replaces the original op.
    rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(pselOp, p1.getType(),
                                                         selected);
    return success();
  }
};
/// Lowers `vector.create_mask` ops whose result is the size of an SVE
/// predicate to the `whilelt` intrinsic, which yields more canonical codegen
/// than the generic LLVM lowering (see
/// https://github.com/llvm/llvm-project/issues/81840 for details).
///
/// The more general active.lane.mask intrinsic is not used because its
/// semantics do not map neatly onto `vector.create_mask`: it performs an
/// unsigned comparison (`create_mask` is signed), and it is UB/poison when `n`
/// is zero (`create_mask` simply returns an all-false mask).
struct CreateMaskOpLowering
    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(vector::CreateMaskOp createMaskOp,
                  vector::CreateMaskOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType maskType = createMaskOp.getVectorType();

    // Only 1-D scalable masks are handled.
    const bool isScalable1D = maskType.getRank() == 1 && maskType.isScalable();
    if (!isScalable1D)
      return rewriter.notifyMatchFailure(createMaskOp, "not 1-D and scalable");

    // The base size must be a power of two in [2, 16].
    // TODO: Support masks which are multiples of SVE predicates.
    const int64_t baseSize = maskType.getDimSize(0);
    const bool isPredicateSized =
        baseSize >= 2 && baseSize <= 16 &&
        llvm::isPowerOf2_32(static_cast<uint32_t>(baseSize));
    if (!isPredicateSized)
      return rewriter.notifyMatchFailure(createMaskOp,
                                         "not SVE predicate-sized");

    // whilelt(0, n), with `n` being the (converted) create_mask operand.
    Location loc = createMaskOp.getLoc();
    auto lowerBound = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type());
    Value upperBound = adaptor.getOperands()[0];
    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType,
                                               lowerBound, upperBound);
    return success();
  }
};
} // namespace
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
// Populate conversion patterns
// clang-format off
patterns.add<ForwardOperands<func::CallOp>,
ForwardOperands<func::CallIndirectOp>,
ForwardOperands<func::ReturnOp>>(converter,
&converter.getContext());
patterns.add<SdotOpLowering,
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
ScalableMaskedSubFOpLowering,
ScalableMaskedMulIOpLowering,
ScalableMaskedMulFOpLowering,
ScalableMaskedSDivIOpLowering,
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering,
ConvertToSvboolOpLowering,
ConvertFromSvboolOpLowering,
ZipX2OpLowering,
ZipX4OpLowering,
PselOpLowering>(converter);
// Add vector.create_mask conversion with a high benefit as it produces much
// nicer code than the generic lowering.
patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
// clang-format on
}
/// Configure `target` for the ArmSVE-to-LLVM legalization.
///
/// All ArmSVE intrinsic ops produced by the lowering patterns in this file are
/// marked legal, and every ArmSVE op that has a dedicated lowering pattern is
/// marked illegal so that a conversion fails loudly if one of them is left
/// behind.
///
/// NOTE(review): `vector.create_mask` is not marked illegal here, presumably
/// because `CreateMaskOpLowering` only matches predicate-sized masks and other
/// masks are left to a different lowering.
void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // clang-format off
  target.addLegalOp<SdotIntrOp,
                    SmmlaIntrOp,
                    UdotIntrOp,
                    UmmlaIntrOp,
                    ScalableMaskedAddIIntrOp,
                    ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp,
                    ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp,
                    ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp,
                    ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp,
                    ConvertToSvboolIntrOp,
                    ConvertFromSvboolIntrOp,
                    ZipX2IntrOp,
                    ZipX4IntrOp,
                    PselIntrOp,
                    WhileLTIntrOp>();
  // PselOp is listed as well: it has a lowering (PselOpLowering) and its
  // intrinsic is legal above, so it must not survive the conversion either.
  target.addIllegalOp<SdotOp,
                      SmmlaOp,
                      UdotOp,
                      UmmlaOp,
                      ScalableMaskedAddIOp,
                      ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp,
                      ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp,
                      ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp,
                      ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp,
                      ConvertToSvboolOp,
                      ConvertFromSvboolOp,
                      ZipX2Op,
                      ZipX4Op,
                      PselOp>();
  // clang-format on
}