//===- MathOps.cpp - MLIR operations for math implementation --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/Builders.h" using namespace mlir; using namespace mlir::math; //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" //===----------------------------------------------------------------------===// // AbsOp folder //===----------------------------------------------------------------------===// OpFoldResult math::AbsOp::fold(ArrayRef operands) { return constFoldUnaryOp(operands, [](const APFloat &a) { const APFloat &result(a); return abs(result); }); } //===----------------------------------------------------------------------===// // CeilOp folder //===----------------------------------------------------------------------===// OpFoldResult math::CeilOp::fold(ArrayRef operands) { return constFoldUnaryOp(operands, [](const APFloat &a) { APFloat result(a); result.roundToIntegral(llvm::RoundingMode::TowardPositive); return result; }); } //===----------------------------------------------------------------------===// // CopySignOp folder //===----------------------------------------------------------------------===// OpFoldResult math::CopySignOp::fold(ArrayRef operands) { return constFoldBinaryOp(operands, [](const APFloat &a, const APFloat &b) { APFloat result(a); result.copySign(b); return result; }); } 
//===----------------------------------------------------------------------===// // CountLeadingZerosOp folder //===----------------------------------------------------------------------===// OpFoldResult math::CountLeadingZerosOp::fold(ArrayRef operands) { return constFoldUnaryOp(operands, [](const APInt &a) { return APInt(a.getBitWidth(), a.countLeadingZeros()); }); } //===----------------------------------------------------------------------===// // CountTrailingZerosOp folder //===----------------------------------------------------------------------===// OpFoldResult math::CountTrailingZerosOp::fold(ArrayRef operands) { return constFoldUnaryOp(operands, [](const APInt &a) { return APInt(a.getBitWidth(), a.countTrailingZeros()); }); } //===----------------------------------------------------------------------===// // CtPopOp folder //===----------------------------------------------------------------------===// OpFoldResult math::CtPopOp::fold(ArrayRef operands) { return constFoldUnaryOp(operands, [](const APInt &a) { return APInt(a.getBitWidth(), a.countPopulation()); }); } //===----------------------------------------------------------------------===// // Log2Op folder //===----------------------------------------------------------------------===// OpFoldResult math::Log2Op::fold(ArrayRef operands) { return constFoldUnaryOpConditional( operands, [](const APFloat &a) -> Optional { if (a.isNegative()) return {}; if (a.getSizeInBits(a.getSemantics()) == 64) return APFloat(log2(a.convertToDouble())); if (a.getSizeInBits(a.getSemantics()) == 32) return APFloat(log2f(a.convertToFloat())); return {}; }); } //===----------------------------------------------------------------------===// // Log10Op folder //===----------------------------------------------------------------------===// OpFoldResult math::Log10Op::fold(ArrayRef operands) { return constFoldUnaryOpConditional( operands, [](const APFloat &a) -> Optional { if (a.isNegative()) return {}; switch 
(a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(log10(a.convertToDouble())); case 32: return APFloat(log10f(a.convertToFloat())); default: return {}; } }); } //===----------------------------------------------------------------------===// // Log1pOp folder //===----------------------------------------------------------------------===// OpFoldResult math::Log1pOp::fold(ArrayRef operands) { return constFoldUnaryOpConditional( operands, [](const APFloat &a) -> Optional { switch (a.getSizeInBits(a.getSemantics())) { case 64: if ((a + APFloat(1.0)).isNegative()) return {}; return APFloat(log1p(a.convertToDouble())); case 32: if ((a + APFloat(1.0f)).isNegative()) return {}; return APFloat(log1pf(a.convertToFloat())); default: return {}; } }); } //===----------------------------------------------------------------------===// // PowFOp folder //===----------------------------------------------------------------------===// OpFoldResult math::PowFOp::fold(ArrayRef operands) { return constFoldBinaryOpConditional( operands, [](const APFloat &a, const APFloat &b) -> Optional { if (a.getSizeInBits(a.getSemantics()) == 64 && b.getSizeInBits(b.getSemantics()) == 64) return APFloat(pow(a.convertToDouble(), b.convertToDouble())); if (a.getSizeInBits(a.getSemantics()) == 32 && b.getSizeInBits(b.getSemantics()) == 32) return APFloat(powf(a.convertToFloat(), b.convertToFloat())); return {}; }); } //===----------------------------------------------------------------------===// // SqrtOp folder //===----------------------------------------------------------------------===// OpFoldResult math::SqrtOp::fold(ArrayRef operands) { return constFoldUnaryOpConditional( operands, [](const APFloat &a) -> Optional { if (a.isNegative()) return {}; switch (a.getSizeInBits(a.getSemantics())) { case 64: return APFloat(sqrt(a.convertToDouble())); case 32: return APFloat(sqrtf(a.convertToFloat())); default: return {}; } }); } /// Materialize an integer or floating point constant. 
///
/// The constant is built as an `arith.constant`. Returns nullptr, so the
/// folding driver can discard the fold, when `value` cannot be used to build an
/// `arith.constant` of `type`.
Operation *math::MathDialect::materializeConstant(OpBuilder &builder,
                                                  Attribute value, Type type,
                                                  Location loc) {
  if (!arith::ConstantOp::isBuildableWith(value, type))
    return nullptr;
  return builder.create<arith::ConstantOp>(loc, value, type);
}