//===- MathToLLVM.cpp - Math to LLVM dialect conversion -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "../PassDetail.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/TypeUtilities.h" using namespace mlir; namespace { using AbsFOpLowering = VectorConvertToLLVMPattern; using AbsIOpLowering = VectorConvertToLLVMPattern; using CeilOpLowering = VectorConvertToLLVMPattern; using CopySignOpLowering = VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; using CtPopFOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; using ExpOpLowering = VectorConvertToLLVMPattern; using FloorOpLowering = VectorConvertToLLVMPattern; using FmaOpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; using LogOpLowering = VectorConvertToLLVMPattern; using PowFOpLowering = VectorConvertToLLVMPattern; using RoundOpLowering = VectorConvertToLLVMPattern; using SinOpLowering = VectorConvertToLLVMPattern; using SqrtOpLowering = VectorConvertToLLVMPattern; // A `CtLz/CtTz(a)` is converted into `CtLz/CtTz(a, false)`. template struct CountOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = CountOpLowering; LogicalResult matchAndRewrite(MathOp op, typename MathOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); auto boolZero = rewriter.getBoolAttr(false); if (!operandType.template isa()) { LLVM::ConstantOp zero = rewriter.create(loc, boolZero); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getOperand(), zero); return success(); } auto vectorType = resultType.template dyn_cast(); if (!vectorType) return failure(); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { LLVM::ConstantOp zero = rewriter.create(loc, boolZero); return rewriter.create(loc, llvm1DVectorTy, operands[0], zero); }, rewriter); } }; using CountLeadingZerosOpLowering = CountOpLowering; using CountTrailingZerosOpLowering = CountOpLowering; // A `expm1` is converted into `exp - 1`. struct ExpM1OpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(math::ExpM1Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); if (!operandType.isa()) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, SplatElementsAttr::get(resultType.cast(), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } auto exp = rewriter.create(loc, adaptor.getOperand()); rewriter.replaceOpWithNewOp(op, operandType, exp, one); return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, floatType), floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); auto exp = rewriter.create(loc, llvm1DVectorTy, operands[0]); return rewriter.create(loc, llvm1DVectorTy, exp, one); }, rewriter); } }; // A `log1p` is converted into `log(1 + ...)`. struct Log1pOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(math::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return rewriter.notifyMatchFailure(op, "unsupported operand type"); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); if (!operandType.isa()) { LLVM::ConstantOp one = LLVM::isCompatibleVectorType(operandType) ? rewriter.create( loc, operandType, SplatElementsAttr::get(resultType.cast(), floatOne)) : rewriter.create(loc, operandType, floatOne); auto add = rewriter.create(loc, operandType, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp(op, operandType, add); return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, floatType), floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); auto add = rewriter.create(loc, llvm1DVectorTy, one, operands[0]); return rewriter.create(loc, llvm1DVectorTy, add); }, rewriter); } }; // A `rsqrt` is converted into `1 / sqrt`. struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(math::RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto operandType = adaptor.getOperand().getType(); if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); auto loc = op.getLoc(); auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); if (!operandType.isa()) { LLVM::ConstantOp one; if (LLVM::isCompatibleVectorType(operandType)) { one = rewriter.create( loc, operandType, SplatElementsAttr::get(resultType.cast(), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } auto sqrt = rewriter.create(loc, adaptor.getOperand()); rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) return failure(); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( {LLVM::getVectorNumElements(llvm1DVectorTy).getFixedValue()}, floatType), floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); auto sqrt = rewriter.create(loc, llvm1DVectorTy, operands[0]); return rewriter.create(loc, llvm1DVectorTy, one, sqrt); }, rewriter); } }; struct ConvertMathToLLVMPass : public ConvertMathToLLVMBase { ConvertMathToLLVMPass() = default; void runOnOperation() override { RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); populateMathToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< AbsFOpLowering, AbsIOpLowering, CeilOpLowering, CopySignOpLowering, CosOpLowering, CountLeadingZerosOpLowering, CountTrailingZerosOpLowering, CtPopFOpLowering, Exp2OpLowering, ExpM1OpLowering, ExpOpLowering, FloorOpLowering, FmaOpLowering, Log10OpLowering, Log1pOpLowering, Log2OpLowering, LogOpLowering, PowFOpLowering, RoundOpLowering, RsqrtOpLowering, SinOpLowering, SqrtOpLowering >(converter); // clang-format on } std::unique_ptr mlir::createConvertMathToLLVMPass() { return std::make_unique(); }