jacquesguan 9d0b90e933 [mlir][Math] Add TruncOp.
This patch adds TruncOp to the Math dialect; it returns the operand rounded to the nearest integer not larger in magnitude than the operand. This patch also adds the corresponding LLVM intrinsic op.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D133342
2022-09-09 10:01:28 +08:00

311 lines
12 KiB
C++

//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOLLVM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
using AbsFOpLowering = VectorConvertToLLVMPattern<math::AbsFOp, LLVM::FAbsOp>;
using CeilOpLowering = VectorConvertToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
using CopySignOpLowering =
VectorConvertToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = VectorConvertToLLVMPattern<math::CosOp, LLVM::CosOp>;
using CtPopFOpLowering =
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
using FloorOpLowering =
VectorConvertToLLVMPattern<math::FloorOp, LLVM::FFloorOp>;
using FmaOpLowering = VectorConvertToLLVMPattern<math::FmaOp, LLVM::FMAOp>;
using Log10OpLowering =
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
using LogOpLowering = VectorConvertToLLVMPattern<math::LogOp, LLVM::LogOp>;
using PowFOpLowering = VectorConvertToLLVMPattern<math::PowFOp, LLVM::PowOp>;
using RoundEvenOpLowering =
VectorConvertToLLVMPattern<math::RoundEvenOp, LLVM::RoundEvenOp>;
using RoundOpLowering =
VectorConvertToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
using SinOpLowering = VectorConvertToLLVMPattern<math::SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
using FTruncOpLowering =
VectorConvertToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
//
// The target LLVM op takes an extra i1 flag operand; this pattern always
// passes a constant `false` for it. Operands the type converter lowered to an
// LLVM array (expected to come from n-D vectors) are unrolled into 1-D
// vectors via `handleMultidimensionalVectors`.
template <typename MathOp, typename LLVMOp>
struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
  using ConvertOpToLLVMPattern<MathOp>::ConvertOpToLLVMPattern;
  using Super = IntOpWithFlagLowering<MathOp, LLVMOp>;

  LogicalResult
  matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandType = adaptor.getOperand().getType();
    if (!operandType || !LLVM::isCompatibleType(operandType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto boolZero = rewriter.getBoolAttr(false);

    if (!operandType.template isa<LLVM::LLVMArrayType>()) {
      // Build the replacement with the *converted* result type (as the
      // sibling lowerings do with their converted operand type) so the new
      // LLVM op never carries a type the type converter would have rewritten.
      // For types that are already LLVM-compatible this is the same type.
      Type llvmResultType = this->getTypeConverter()->convertType(resultType);
      if (!llvmResultType)
        return rewriter.notifyMatchFailure(op,
                                           "failed to convert result type");
      LLVM::ConstantOp zero = rewriter.create<LLVM::ConstantOp>(loc, boolZero);
      rewriter.replaceOpWithNewOp<LLVMOp>(op, llvmResultType,
                                          adaptor.getOperand(), zero);
      return success();
    }

    // Array-typed operand: only n-D vector results can be unrolled.
    if (!resultType.template isa<VectorType>())
      return rewriter.notifyMatchFailure(op, "expected vector result type");
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
        [&](Type llvm1DVectorTy, ValueRange operands) {
          LLVM::ConstantOp zero =
              rewriter.create<LLVM::ConstantOp>(loc, boolZero);
          return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
                                         zero);
        },
        rewriter);
  }
};

using CountLeadingZerosOpLowering =
    IntOpWithFlagLowering<math::CountLeadingZerosOp, LLVM::CountLeadingZerosOp>;
using CountTrailingZerosOpLowering =
    IntOpWithFlagLowering<math::CountTrailingZerosOp, LLVM::CountTrailingZerosOp>;
using AbsIOpLowering = IntOpWithFlagLowering<math::AbsIOp, LLVM::AbsOp>;
// An `expm1(x)` is converted into `exp(x) - 1`.
//
// NOTE: evaluating `exp(x) - 1` directly can lose relative precision for
// inputs close to zero (cancellation in the subtraction), unlike a dedicated
// expm1 routine.
//
// Scalars and 1-D vectors are rewritten in place; operands the type converter
// lowered to an LLVM array (expected to come from n-D vectors) are unrolled
// into 1-D vectors via `handleMultidimensionalVectors`.
struct ExpM1OpLowering : public ConvertOpToLLVMPattern<math::ExpM1Op> {
  using ConvertOpToLLVMPattern<math::ExpM1Op>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandType = adaptor.getOperand().getType();
    if (!operandType || !LLVM::isCompatibleType(operandType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);

    if (!operandType.isa<LLVM::LLVMArrayType>()) {
      // Materialize `1.0` in the converted operand type, splatted when the
      // operand is a vector.
      LLVM::ConstantOp one;
      if (LLVM::isCompatibleVectorType(operandType)) {
        one = rewriter.create<LLVM::ConstantOp>(
            loc, operandType,
            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
      } else {
        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
      }
      auto exp = rewriter.create<LLVM::ExpOp>(loc, adaptor.getOperand());
      rewriter.replaceOpWithNewOp<LLVM::FSubOp>(op, operandType, exp, one);
      return success();
    }

    // Array-typed operand: only n-D vector results can be unrolled.
    if (!resultType.isa<VectorType>())
      return rewriter.notifyMatchFailure(op, "expected vector result type");
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
        [&](Type llvm1DVectorTy, ValueRange operands) {
          // NOTE(review): `getFixedValue()` assumes a fixed-length 1-D
          // vector; scalable vectors are presumably not supported here —
          // confirm.
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get(
                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
                  floatType),
              floatOne);
          auto one =
              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
          auto exp =
              rewriter.create<LLVM::ExpOp>(loc, llvm1DVectorTy, operands[0]);
          return rewriter.create<LLVM::FSubOp>(loc, llvm1DVectorTy, exp, one);
        },
        rewriter);
  }
};
// Lowers `math.log1p(x)` to `llvm.log(1 + x)`.
//
// NOTE: forming `1 + x` before taking the logarithm can lose relative
// precision for inputs close to zero compared with a dedicated log1p routine.
struct Log1pOpLowering : public ConvertOpToLLVMPattern<math::Log1pOp> {
  using ConvertOpToLLVMPattern<math::Log1pOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getOperand();
    Type llvmType = input.getType();
    if (!llvmType || !LLVM::isCompatibleType(llvmType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    Location loc = op.getLoc();
    Type mathType = op.getResult().getType();
    auto elemType = getElementTypeOrSelf(mathType).cast<FloatType>();
    FloatAttr oneAttr = rewriter.getFloatAttr(elemType, 1.0);

    // Operands lowered to an LLVM array are unrolled into 1-D vector slices,
    // each of which gets its own `log(1 + slice)`.
    if (llvmType.isa<LLVM::LLVMArrayType>()) {
      if (!mathType.isa<VectorType>())
        return rewriter.notifyMatchFailure(op, "expected vector result type");
      return LLVM::detail::handleMultidimensionalVectors(
          op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
          [&](Type llvm1DVectorTy, ValueRange operands) {
            int64_t width =
                LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue();
            auto splatTy = mlir::VectorType::get({width}, elemType);
            auto oneSplat = rewriter.create<LLVM::ConstantOp>(
                loc, llvm1DVectorTy, SplatElementsAttr::get(splatTy, oneAttr));
            auto sum = rewriter.create<LLVM::FAddOp>(loc, llvm1DVectorTy,
                                                     oneSplat, operands[0]);
            return rewriter.create<LLVM::LogOp>(loc, llvm1DVectorTy, sum);
          },
          rewriter);
    }

    // Scalar or 1-D vector: build `1.0` directly in the converted operand
    // type (as a splat when it is a vector), then emit `log(1 + x)`.
    LLVM::ConstantOp one;
    if (LLVM::isCompatibleVectorType(llvmType))
      one = rewriter.create<LLVM::ConstantOp>(
          loc, llvmType,
          SplatElementsAttr::get(mathType.cast<ShapedType>(), oneAttr));
    else
      one = rewriter.create<LLVM::ConstantOp>(loc, llvmType, oneAttr);
    Value sum = rewriter.create<LLVM::FAddOp>(loc, llvmType, one, input);
    rewriter.replaceOpWithNewOp<LLVM::LogOp>(op, llvmType, sum);
    return success();
  }
};
// An `rsqrt(x)` is converted into `1 / sqrt(x)`.
//
// Scalars and 1-D vectors are rewritten in place; operands the type converter
// lowered to an LLVM array (expected to come from n-D vectors) are unrolled
// into 1-D vectors via `handleMultidimensionalVectors`.
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<math::RsqrtOp> {
  using ConvertOpToLLVMPattern<math::RsqrtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto operandType = adaptor.getOperand().getType();
    if (!operandType || !LLVM::isCompatibleType(operandType))
      return rewriter.notifyMatchFailure(op, "unsupported operand type");

    auto loc = op.getLoc();
    auto resultType = op.getResult().getType();
    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);

    if (!operandType.isa<LLVM::LLVMArrayType>()) {
      // Materialize `1.0` in the converted operand type, splatted when the
      // operand is a vector.
      LLVM::ConstantOp one;
      if (LLVM::isCompatibleVectorType(operandType)) {
        one = rewriter.create<LLVM::ConstantOp>(
            loc, operandType,
            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
      } else {
        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
      }
      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, adaptor.getOperand());
      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
      return success();
    }

    // Array-typed operand: only n-D vector results can be unrolled.
    if (!resultType.isa<VectorType>())
      return rewriter.notifyMatchFailure(op, "expected vector result type");
    return LLVM::detail::handleMultidimensionalVectors(
        op.getOperation(), adaptor.getOperands(), *getTypeConverter(),
        [&](Type llvm1DVectorTy, ValueRange operands) {
          // NOTE(review): `getFixedValue()` assumes a fixed-length 1-D
          // vector; scalable vectors are presumably not supported here —
          // confirm.
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get(
                  {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()},
                  floatType),
              floatOne);
          auto one =
              rewriter.create<LLVM::ConstantOp>(loc, llvm1DVectorTy, splatAttr);
          auto sqrt =
              rewriter.create<LLVM::SqrtOp>(loc, llvm1DVectorTy, operands[0]);
          return rewriter.create<LLVM::FDivOp>(loc, llvm1DVectorTy, one, sqrt);
        },
        rewriter);
  }
};
// Pass that applies the Math-to-LLVM patterns as a partial conversion to the
// operation it runs on, signalling pass failure if the conversion fails.
struct ConvertMathToLLVMPass
    : public impl::ConvertMathToLLVMBase<ConvertMathToLLVMPass> {
  ConvertMathToLLVMPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();

    RewritePatternSet mathPatterns(context);
    LLVMTypeConverter typeConverter(context);
    populateMathToLLVMConversionPatterns(typeConverter, mathPatterns);

    LLVMConversionTarget llvmTarget(*context);
    LogicalResult conversionResult = applyPartialConversion(
        getOperation(), llvmTarget, std::move(mathPatterns));
    if (failed(conversionResult))
      signalPassFailure();
  }
};
} // namespace
/// Registers every Math-to-LLVM lowering pattern from this file into
/// `patterns`, each constructed with `converter`.
void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
  // One-to-one lowerings built on `VectorConvertToLLVMPattern`.
  // clang-format off
  patterns.add<
    AbsFOpLowering,
    CeilOpLowering,
    CopySignOpLowering,
    CosOpLowering,
    CtPopFOpLowering,
    Exp2OpLowering,
    ExpOpLowering,
    FloorOpLowering,
    FmaOpLowering,
    FTruncOpLowering,
    Log10OpLowering,
    Log2OpLowering,
    LogOpLowering,
    PowFOpLowering,
    RoundEvenOpLowering,
    RoundOpLowering,
    SinOpLowering,
    SqrtOpLowering
  >(converter);
  // clang-format on

  // Integer lowerings that append a constant `false` flag operand.
  patterns.add<AbsIOpLowering, CountLeadingZerosOpLowering,
               CountTrailingZerosOpLowering>(converter);

  // Lowerings that expand a single Math op into several LLVM ops.
  patterns.add<ExpM1OpLowering, Log1pOpLowering, RsqrtOpLowering>(converter);
}
/// Creates a new instance of the Math-to-LLVM conversion pass.
std::unique_ptr<Pass> mlir::createConvertMathToLLVMPass() {
  auto pass = std::make_unique<ConvertMathToLLVMPass>();
  // Implicitly moved and converted to `std::unique_ptr<Pass>` on return.
  return pass;
}