Tres Popp 4639a85f94 [mlir] Add math.roundeven and llvm.intr.roundeven
This is similar to math.round, but rounds to even instead of rounding away from
zero in the case of halfway values. This CL also adds lowerings to libm and
to the LLVM intrinsic.

Differential Revision: https://reviews.llvm.org/D132375
2022-08-25 13:39:01 +02:00

302 lines
12 KiB
C++

//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir;
namespace {
// Direct lowerings: each alias rewrites the Math op given as the first template
// argument into the LLVM dialect op given as the second, using
// `VectorConvertToLLVMPattern` (declared in LLVMCommon/VectorPattern.h).

// Sign manipulation and rounding.
using AbsFOpLowering = VectorConvertToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
using CopySignOpLowering =
    VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
using FloorOpLowering =
    VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
using RoundOpLowering =
    VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
using RoundEvenOpLowering =
    VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;

// Population count.
// NOTE(review): the `F` suffix looks misleading (ctpop is a bit-counting op);
// the name is kept because the pattern registration below refers to it.
using CtPopFOpLowering =
    VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;

// Exponentials, logarithms, powers and roots.
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
using Log10OpLowering =
    VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;

// Trigonometry.
using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
//
// The LLVM counterparts take an extra `i1` flag operand (`is_zero_poison` for
// ctlz/cttz, `is_int_min_poison` for abs). This pattern always passes the
// constant `false`.
//
// Scalars and 1-D vectors, whose converted type is not an LLVM array, are
// rewritten in place. n-D vectors, which the type converter lowers to LLVM
// arrays, go through `handleMultidimensionalVectors`. Its callback builds the
// op for each 1-D vector slice. A single flag constant, created before that
// call, is shared by every slice.
template <typename MathOp, typename LLVMOp>
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
  using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;

  LogicalResult
  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandType = adaptor.getOperand().getType();
    if (!operandType || !LLVM::isCompatibleType(operandType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto boolZero = rewriter.getBoolAttr(false);

    // Scalar or 1-D vector: emit the flag and the LLVM op directly.
    if (!operandType.template isa<LLVM::LLVMArrayType>()) {
      LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
      rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
                                          zero);
      return success();
    }

    // An LLVM array operand is only expected for an n-D vector result.
    if (!resultType.template isa<VectorType>())
      return rewriter.notifyMatchFailure(op, "expected vector result type");

    // The flag does not depend on the slice being built, so create it once
    // here rather than once per 1-D slice. It precedes, and therefore
    // dominates, every op created by the callback below.
    LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
        [&](Type llvm1DVectorTy, ValueRange operands) {
          return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
                                         zero);
        },
        rewriter);
  }
};
// Integer Math ops whose LLVM counterparts carry an extra `i1` flag operand;
// all of them are lowered by `IntOpWithFlagLowering`.
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
using CountLeadingZerosOpLowering =
    IntOpWithFlagLowering<math::CountLeadingZerosOp,
                          LLVM::CountLeadingZerosOp>;
using CountTrailingZerosOpLowering =
    IntOpWithFlagLowering<math::CountTrailingZerosOp,
                          LLVM::CountTrailingZerosOp>;
// A `expm1` is converted into `exp - 1`.
//
// Scalars and 1-D vectors are rewritten in place as
// `llvm.fsub(llvm.exp(x), 1.0)`. For vectors the 1.0 is a dense splat
// constant. n-D vectors, which the type converter turns into LLVM arrays, are
// rewritten one 1-D slice at a time through `handleMultidimensionalVectors`.
//
// NOTE: subtracting 1 from `exp(x)` cancels leading bits when `x` is close to
// zero. Results there are less accurate than a dedicated expm1 would give.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandType = adaptor.getOperand().getType();
    if (!operandType || !LLVM::isCompatibleType(operandType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);

    if (!operandType.isa<LLVM::LLVMArrayType>()) {
      // A 1-D vector needs the 1.0 splatted across the vector shape. A scalar
      // uses the float attribute as-is.
      LLVM::ConstantOp one;
      if (LLVM::isCompatibleVectorType(operandType)) {
        one = rewriter.create<LLVM::ConstantOp>(
            loc, operandType,
            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
      } else {
        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
      }
      auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
      return success();
    }

    // An LLVM array operand is only expected for an n-D vector result.
    if (!resultType.isa<VectorType>())
      return rewriter.notifyMatchFailure(op, "expected vector result type");

    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
        [&](Type llvm1DVectorTy, ValueRange operands) {
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get(
                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
                  floatType),
              floatOne);
          auto one =
              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
          auto exp =
              rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
        },
        rewriter);
  }
};
// A `log1p` is converted into `log(1 + ...)`.
//
// Rewrites `math.log1p(x)` as `llvm.log(llvm.fadd(1.0, x))`. Operands whose
// converted type is not an LLVM array (scalars, 1-D vectors) are rewritten in
// place. n-D vectors are handled slice by slice through
// `handleMultidimensionalVectors`.
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getOperand();
    Type llvmType = input.getType();
    if (!llvmType || !LLVM::isCompatibleType(llvmType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    Location loc = op.getLoc();
    Type origType = op.getResult().getType();
    auto elemType = getElementTypeOrSelf(origType).cast<FloatType>();
    auto unit = rewriter.getFloatAttr(elemType, 1.0);

    const bool isArray = llvmType.isa<LLVM::LLVMArrayType>();
    if (!isArray) {
      // Scalars take the float attribute directly; 1-D vectors take a splat.
      LLVM::ConstantOp oneCst;
      if (!LLVM::isCompatibleVectorType(llvmType)) {
        oneCst = rewriter.create<LLVM::ConstantOp>(loc, llvmType, unit);
      } else {
        oneCst = rewriter.create<LLVM::ConstantOp>(
            loc, llvmType,
            SplatElementsAttr::get(origType.cast<ShapedType>(), unit));
      }
      auto sum = rewriter.create<LLVM::FAddOp>(loc, llvmType, oneCst, input);
      rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, llvmType, sum);
      return success();
    }

    if (!origType.isa<VectorType>())
      return rewriter.notifyMatchFailure(op, "expected vector result type");

    // Builds `log(1 + slice)` for one 1-D slice of the n-D vector.
    auto buildSlice = [&](Type sliceTy, ValueRange sliceOperands) {
      int64_t width = LLVM::getVectorNumElements(sliceTy).getFixedValue();
      auto onesAttr =
          SplatElementsAttr::get(mlir::VectorType::get({width}, elemType), unit);
      auto ones = rewriter.create<LLVM::ConstantOp>(loc, sliceTy, onesAttr);
      auto sum =
          rewriter.create<LLVM::FAddOp>(loc, sliceTy, ones, sliceOperands[0]);
      return rewriter.create<LLVM::LogOp>(loc, sliceTy, sum);
    };
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
        buildSlice, rewriter);
  }
};
// A `rsqrt` is converted into `1 / sqrt`.
//
// Scalars and 1-D vectors are rewritten in place as
// `llvm.fdiv(1.0, llvm.sqrt(x))`. For vectors the 1.0 is a dense splat
// constant. n-D vectors, which the type converter turns into LLVM arrays, are
// rewritten one 1-D slice at a time through `handleMultidimensionalVectors`.
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
  using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandType = adaptor.getOperand().getType();
    if (!operandType || !LLVM::isCompatibleType(operandType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);

    if (!operandType.isa<LLVM::LLVMArrayType>()) {
      // A 1-D vector needs the 1.0 splatted across the vector shape. A scalar
      // uses the float attribute as-is.
      LLVM::ConstantOp one;
      if (LLVM::isCompatibleVectorType(operandType)) {
        one = rewriter.create<LLVM::ConstantOp>(
            loc, operandType,
            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
      } else {
        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
      }
      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
      return success();
    }

    // An LLVM array operand is only expected for an n-D vector result.
    if (!resultType.isa<VectorType>())
      return rewriter.notifyMatchFailure(op, "expected vector result type");

    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
        [&](Type llvm1DVectorTy, ValueRange operands) {
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get(
                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
                  floatType),
              floatOne);
          auto one =
              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
          auto sqrt =
              rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
          return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
        },
        rewriter);
  }
};
// Pass that applies the Math-to-LLVM patterns from
// `populateMathToLLVMConversionPatterns` as a partial conversion towards an
// `LLVMConversionTarget`. It signals pass failure if the conversion fails.
struct ConvertMathToLLVMPass
    : public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
  ConvertMathToLLVMPass() = default;

  void runOnOperation() override {
    MLIRContext &ctx = getContext();

    RewritePatternSet mathPatterns(&ctx);
    LLVMTypeConverter typeConverter(&ctx);
    populateMathToLLVMConversionPatterns(typeConverter, mathPatterns);

    LLVMConversionTarget llvmTarget(ctx);
    LogicalResult outcome = applyPartialConversion(
        getOperation(), llvmTarget, std::move(mathPatterns));
    if (failed(outcome))
      signalPassFailure();
  }
};
} // namespace
/// Adds every Math-to-LLVM lowering pattern defined in this file to `patterns`.
/// Each pattern is constructed with `converter`. Registration order is
/// alphabetical and is split across two chained `add` calls only for
/// readability.
void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
  // clang-format off
  patterns
      .add<AbsFOpLowering,
           AbsIOpLowering,
           CeilOpLowering,
           CopySignOpLowering,
           CosOpLowering,
           CountLeadingZerosOpLowering,
           CountTrailingZerosOpLowering,
           CtPopFOpLowering,
           Exp2OpLowering,
           ExpM1OpLowering,
           ExpOpLowering,
           FloorOpLowering>(converter)
      .add<FmaOpLowering,
           Log10OpLowering,
           Log1pOpLowering,
           Log2OpLowering,
           LogOpLowering,
           PowFOpLowering,
           RoundEvenOpLowering,
           RoundOpLowering,
           RsqrtOpLowering,
           SinOpLowering,
           SqrtOpLowering>(converter);
  // clang-format on
}
/// Creates a new instance of the Math-to-LLVM conversion pass; the caller owns
/// the returned pass.
std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
  std::unique_ptr<Pass> pass = std::make_unique<ConvertMathToLLVMPass>();
  return pass;
}