//===- ArithToLLVM.cpp - Arithmetic to LLVM dialect conversion -------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" namespace mlir { #define GEN_PASS_DEF_ARITHTOLLVMCONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { // Map arithmetic fastmath enum values to LLVMIR enum values. static LLVM::FastmathFlags convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) { LLVM::FastmathFlags llvmFMF{}; const std::pair flags[] = { {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract}, {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn}, {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; for (auto fmfMap : flags) { if (bitEnumContainsAny(arithFMF, fmfMap.first)) llvmFMF = llvmFMF | fmfMap.second; } return llvmFMF; } // Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. static LLVM::FastmathFlagsAttr convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) { arith::FastMathFlags arithFMF = fmfAttr.getValue(); return LLVM::FastmathFlagsAttr::get( fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF)); } // Attribute converter that populates a NamedAttrList by removing the fastmath // attribute from the source operation attributes, and replacing it with an // equivalent LLVM fastmath attribute. template class AttrConvertFastMath { public: AttrConvertFastMath(SourceOp srcOp) { // Copy the source attributes. convertedAttr = NamedAttrList{srcOp->getAttrs()}; // Get the name of the arith fastmath attribute. llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); // Remove the source fastmath attribute. auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName) .template dyn_cast_or_null(); if (arithFMFAttr) { llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr)); } } ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } private: NamedAttrList convertedAttr; }; // Attribute converter that populates a NamedAttrList by removing the fastmath // attribute from the source operation attributes. This may be useful for // target operations that do not require the fastmath attribute, or for targets // that do not yet support the LLVM fastmath attribute. template class AttrDropFastMath { public: AttrDropFastMath(SourceOp srcOp) { // Copy the source attributes. convertedAttr = NamedAttrList{srcOp->getAttrs()}; // Get the name of the arith fastmath attribute. llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); // Remove the source fastmath attribute. convertedAttr.erase(arithFMFAttrName); } ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } private: NamedAttrList convertedAttr; }; //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// using AddFOpLowering = VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndIOpLowering = VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; using DivSIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = VectorConvertToLLVMPattern; using ExtFOpLowering = VectorConvertToLLVMPattern; using ExtSIOpLowering = VectorConvertToLLVMPattern; using ExtUIOpLowering = VectorConvertToLLVMPattern; using FPToSIOpLowering = VectorConvertToLLVMPattern; using FPToUIOpLowering = VectorConvertToLLVMPattern; // TODO: Add LLVM intrinsic support for fastmath using MaxFOpLowering = VectorConvertToLLVMPattern; using MaxSIOpLowering = VectorConvertToLLVMPattern; using MaxUIOpLowering = VectorConvertToLLVMPattern; // TODO: Add LLVM intrinsic support for fastmath using MinFOpLowering = VectorConvertToLLVMPattern; using MinSIOpLowering = VectorConvertToLLVMPattern; using MinUIOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using OrIOpLowering = VectorConvertToLLVMPattern; // TODO: Add LLVM intrinsic support for fastmath using RemFOpLowering = VectorConvertToLLVMPattern; using RemSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = VectorConvertToLLVMPattern; using ShLIOpLowering = VectorConvertToLLVMPattern; using ShRSIOpLowering = VectorConvertToLLVMPattern; using ShRUIOpLowering = VectorConvertToLLVMPattern; using SIToFPOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = VectorConvertToLLVMPattern; using TruncIOpLowering = VectorConvertToLLVMPattern; using UIToFPOpLowering = VectorConvertToLLVMPattern; using XOrIOpLowering = VectorConvertToLLVMPattern; //===----------------------------------------------------------------------===// // Op Lowering Patterns //===----------------------------------------------------------------------===// /// Directly lower to LLVM op. struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// The lowering of index_cast becomes an integer conversion since index /// becomes an integer. If the bit width of the source and target integer /// types is the same, just erase the cast. If the target type is wider, /// sign-extend the value, otherwise truncate it. template struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; using IndexCastOpSILowering = IndexCastOpLowering; using IndexCastOpUILowering = IndexCastOpLowering; struct AddUICarryOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::AddUICarryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace //===----------------------------------------------------------------------===// // ConstantOpLowering //===----------------------------------------------------------------------===// LogicalResult ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), op->getAttrs(), *getTypeConverter(), rewriter); } //===----------------------------------------------------------------------===// // IndexCastOpLowering //===----------------------------------------------------------------------===// template LogicalResult IndexCastOpLowering::matchAndRewrite( OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { Type resultType = op.getResult().getType(); Type targetElementType = this->typeConverter->convertType(getElementTypeOrSelf(resultType)); Type sourceElementType = this->typeConverter->convertType(getElementTypeOrSelf(op.getIn())); unsigned targetBits = targetElementType.getIntOrFloatBitWidth(); unsigned sourceBits = sourceElementType.getIntOrFloatBitWidth(); if (targetBits == sourceBits) { rewriter.replaceOp(op, adaptor.getIn()); return success(); } // Handle the scalar and 1D vector cases. Type operandType = adaptor.getIn().getType(); if (!operandType.isa()) { Type targetType = this->typeConverter->convertType(resultType); if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); else rewriter.replaceOpWithNewOp(op, targetType, adaptor.getIn()); return success(); } if (!resultType.isa()) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *(this->getTypeConverter()), [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename OpTy::Adaptor adaptor(operands); if (targetBits < sourceBits) { return rewriter.create(op.getLoc(), llvm1DVectorTy, adaptor.getIn()); } return rewriter.create(op.getLoc(), llvm1DVectorTy, adaptor.getIn()); }, rewriter); } //===----------------------------------------------------------------------===// // AddUICarryOpLowering //===----------------------------------------------------------------------===// LogicalResult AddUICarryOpLowering::matchAndRewrite( arith::AddUICarryOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = adaptor.getLhs().getType(); Type sumResultType = op.getSum().getType(); Type carryResultType = op.getCarry().getType(); if (!LLVM::isCompatibleType(operandType)) return failure(); MLIRContext *ctx = rewriter.getContext(); Location loc = op.getLoc(); // Handle the scalar and 1D vector cases. if (!operandType.isa()) { Type newCarryType = typeConverter->convertType(carryResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newCarryType}); Value addOverflow = rewriter.create( loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = rewriter.create(loc, addOverflow, 0); Value carryExtracted = rewriter.create(loc, addOverflow, 1); rewriter.replaceOp(op, {sumExtracted, carryExtracted}); return success(); } if (!sumResultType.isa()) return rewriter.notifyMatchFailure(loc, "expected vector result types"); return rewriter.notifyMatchFailure(loc, "ND vector types are not supported yet"); } //===----------------------------------------------------------------------===// // CmpIOpLowering //===----------------------------------------------------------------------===// // Convert arith.cmp predicate into the LLVM dialect CmpPredicate. The two enums // share numerical values so just cast. template static LLVMPredType convertCmpPredicate(PredType pred) { return static_cast(pred); } LogicalResult CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. if (!operandType.isa()) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); return success(); } if (!resultType.isa()) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, rewriter); } //===----------------------------------------------------------------------===// // CmpFOpLowering //===----------------------------------------------------------------------===// LogicalResult CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type operandType = adaptor.getLhs().getType(); Type resultType = op.getResult().getType(); // Handle the scalar and 1D vector cases. if (!operandType.isa()) { rewriter.replaceOpWithNewOp( op, typeConverter->convertType(resultType), convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); return success(); } if (!resultType.isa()) return rewriter.notifyMatchFailure(op, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); return rewriter.create( op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, rewriter); } //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// namespace { struct ArithToLLVMConversionPass : public impl::ArithToLLVMConversionPassBase { using Base::Base; void runOnOperation() override { LLVMConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); LowerToLLVMOptions options(&getContext()); if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); LLVMTypeConverter converter(&getContext(), options); mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern Population //===----------------------------------------------------------------------===// void mlir::arith::populateArithToLLVMConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns) { // clang-format off patterns.add< AddFOpLowering, AddIOpLowering, AndIOpLowering, AddUICarryOpLowering, BitcastOpLowering, ConstantOpLowering, CmpFOpLowering, CmpIOpLowering, DivFOpLowering, DivSIOpLowering, DivUIOpLowering, ExtFOpLowering, ExtSIOpLowering, ExtUIOpLowering, FPToSIOpLowering, FPToUIOpLowering, IndexCastOpSILowering, IndexCastOpUILowering, MaxFOpLowering, MaxSIOpLowering, MaxUIOpLowering, MinFOpLowering, MinSIOpLowering, MinUIOpLowering, MulFOpLowering, MulIOpLowering, NegFOpLowering, OrIOpLowering, RemFOpLowering, RemSIOpLowering, RemUIOpLowering, SelectOpLowering, ShLIOpLowering, ShRSIOpLowering, ShRUIOpLowering, SIToFPOpLowering, SubFOpLowering, SubIOpLowering, TruncFOpLowering, TruncIOpLowering, UIToFPOpLowering, XOrIOpLowering >(converter); // clang-format on }