
This is similar to math.round, but rounds to even instead of rounding away from zero in the case of halfway values. This CL also adds lowerings to libm and to the LLVM intrinsic. Differential Revision: https://reviews.llvm.org/D132375
302 lines
12 KiB
C++
302 lines
12 KiB
C++
//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
using AbsFOpLowering = VectorConvertToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
|
|
using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
|
|
using CopySignOpLowering =
|
|
VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
|
|
using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
|
|
using CtPopFOpLowering =
|
|
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
|
|
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
|
|
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
|
|
using FloorOpLowering =
|
|
VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
|
|
using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
|
|
using Log10OpLowering =
|
|
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
|
|
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
|
|
using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
|
|
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
|
|
using RoundEvenOpLowering =
|
|
VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
|
|
using RoundOpLowering =
|
|
VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
|
|
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
|
|
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
|
|
|
|
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
|
|
template <typename MathOp, typename LLVMOp>
|
|
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
|
|
using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
|
|
using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto operandType = adaptor.getOperand().getType();
|
|
|
|
if (!operandType || !LLVM::isCompatibleType(operandType))
|
|
return failure();
|
|
|
|
auto loc = op.getLoc();
|
|
auto resultType = op.getResult().getType();
|
|
auto boolZero = rewriter.getBoolAttr(false);
|
|
|
|
if (!operandType.template isa<LLVM::LLVMArrayType>()) {
|
|
LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
|
|
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
|
|
zero);
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.template dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return failure();
|
|
|
|
return LLVM::detail::handleMultidimensionalVectors(
|
|
op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
LLVM::ConstantOp zero =
|
|
rewriter.create<LLVM::ConstantOp>(loc, boolZero);
|
|
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
|
|
zero);
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
using CountLeadingZerosOpLowering =
|
|
IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
|
|
using CountTrailingZerosOpLowering =
|
|
IntOpWithFlagLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
|
|
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
|
|
|
|
// A `expm1` is converted into `exp - 1`.
|
|
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
|
|
using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto operandType = adaptor.getOperand().getType();
|
|
|
|
if (!operandType || !LLVM::isCompatibleType(operandType))
|
|
return failure();
|
|
|
|
auto loc = op.getLoc();
|
|
auto resultType = op.getResult().getType();
|
|
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
|
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
|
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
LLVM::ConstantOp one;
|
|
if (LLVM::isCompatibleVectorType(operandType)) {
|
|
one = rewriter.create<LLVM::ConstantOp>(
|
|
loc, operandType,
|
|
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
|
|
} else {
|
|
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
|
|
}
|
|
auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
|
|
rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return rewriter.notifyMatchFailure(op, "expected vector result type");
|
|
|
|
return LLVM::detail::handleMultidimensionalVectors(
|
|
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
auto splatAttr = SplatElementsAttr::get(
|
|
mlir::VectorType::get(
|
|
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
|
|
floatType),
|
|
floatOne);
|
|
auto one =
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
|
|
auto exp =
|
|
rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
|
|
return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
// A `log1p` is converted into `log(1 + ...)`.
|
|
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
|
|
using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto operandType = adaptor.getOperand().getType();
|
|
|
|
if (!operandType || !LLVM::isCompatibleType(operandType))
|
|
return rewriter.notifyMatchFailure(op, "unsupported operand type");
|
|
|
|
auto loc = op.getLoc();
|
|
auto resultType = op.getResult().getType();
|
|
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
|
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
|
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
LLVM::ConstantOp one =
|
|
LLVM::isCompatibleVectorType(operandType)
|
|
? rewriter.create<LLVM::ConstantOp>(
|
|
loc, operandType,
|
|
SplatElementsAttr::get(resultType.cast<ShapedType>(),
|
|
floatOne))
|
|
: rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
|
|
|
|
auto add = rewriter.create<LLVM::FAddOp>(loc, operandType, one,
|
|
adaptor.getOperand());
|
|
rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, operandType, add);
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return rewriter.notifyMatchFailure(op, "expected vector result type");
|
|
|
|
return LLVM::detail::handleMultidimensionalVectors(
|
|
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
auto splatAttr = SplatElementsAttr::get(
|
|
mlir::VectorType::get(
|
|
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
|
|
floatType),
|
|
floatOne);
|
|
auto one =
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
|
|
auto add = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy, one,
|
|
operands[0]);
|
|
return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, add);
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
// A `rsqrt` is converted into `1 / sqrt`.
|
|
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
|
|
using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto operandType = adaptor.getOperand().getType();
|
|
|
|
if (!operandType || !LLVM::isCompatibleType(operandType))
|
|
return failure();
|
|
|
|
auto loc = op.getLoc();
|
|
auto resultType = op.getResult().getType();
|
|
auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
|
|
auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
|
|
|
|
if (!operandType.isa<LLVM::LLVMArrayType>()) {
|
|
LLVM::ConstantOp one;
|
|
if (LLVM::isCompatibleVectorType(operandType)) {
|
|
one = rewriter.create<LLVM::ConstantOp>(
|
|
loc, operandType,
|
|
SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
|
|
} else {
|
|
one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
|
|
}
|
|
auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
|
|
rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
|
|
return success();
|
|
}
|
|
|
|
auto vectorType = resultType.dyn_cast<VectorType>();
|
|
if (!vectorType)
|
|
return failure();
|
|
|
|
return LLVM::detail::handleMultidimensionalVectors(
|
|
op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
|
|
[&](Type llvm1DVectorTy, ValueRange operands) {
|
|
auto splatAttr = SplatElementsAttr::get(
|
|
mlir::VectorType::get(
|
|
{LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
|
|
floatType),
|
|
floatOne);
|
|
auto one =
|
|
rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
|
|
auto sqrt =
|
|
rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
|
|
return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
|
|
},
|
|
rewriter);
|
|
}
|
|
};
|
|
|
|
struct ConvertMathToLLVMPass
|
|
: public ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
|
|
ConvertMathToLLVMPass() = default;
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
LLVMTypeConverter converter(&getContext());
|
|
populateMathToLLVMConversionPatterns(converter, patterns);
|
|
LLVMConversionTarget target(getContext());
|
|
if (failed(applyPartialConversion(getOperation(), target,
|
|
std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
// Adds every math-to-LLVM lowering pattern defined in this file to
// `patterns`, each constructed with `converter`. The list is kept in
// alphabetical order.
void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
  // clang-format off
  patterns.add<AbsFOpLowering, AbsIOpLowering,
               CeilOpLowering, CopySignOpLowering, CosOpLowering,
               CountLeadingZerosOpLowering, CountTrailingZerosOpLowering,
               CtPopFOpLowering,
               Exp2OpLowering, ExpM1OpLowering, ExpOpLowering,
               FloorOpLowering, FmaOpLowering,
               Log10OpLowering, Log1pOpLowering, Log2OpLowering, LogOpLowering,
               PowFOpLowering,
               RoundEvenOpLowering, RoundOpLowering, RsqrtOpLowering,
               SinOpLowering, SqrtOpLowering>(converter);
  // clang-format on
}
|
|
|
|
std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
|
|
return std::make_unique<ConvertMathToLLVMPass>();
|
|
}
|