
These are identified by misc-include-cleaner. I've filtered out those that break builds. Also, I'm staying away from llvm-config.h, config.h, and Compiler.h, which likely cause platform- or compiler-specific build failures.
1104 lines
48 KiB
C++
1104 lines
48 KiB
C++
//===- ComplexToStandard.cpp - conversion from Complex to Standard dialect ===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"
|
|
|
|
#include "mlir/Conversion/ComplexCommon/DivisionConverter.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include <type_traits>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTCOMPLEXTOSTANDARDPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
enum class AbsFn { abs, sqrt, rsqrt };
|
|
|
|
// Returns the absolute value, its square root or its reciprocal square root.
|
|
Value computeAbs(Value real, Value imag, arith::FastMathFlags fmf,
|
|
ImplicitLocOpBuilder &b, AbsFn fn = AbsFn::abs) {
|
|
Value one = arith::ConstantOp::create(b, real.getType(),
|
|
b.getFloatAttr(real.getType(), 1.0));
|
|
|
|
Value absReal = math::AbsFOp::create(b, real, fmf);
|
|
Value absImag = math::AbsFOp::create(b, imag, fmf);
|
|
|
|
Value max = arith::MaximumFOp::create(b, absReal, absImag, fmf);
|
|
Value min = arith::MinimumFOp::create(b, absReal, absImag, fmf);
|
|
|
|
// The lowering below requires NaNs and infinities to work correctly.
|
|
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
|
|
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
|
|
Value ratio = arith::DivFOp::create(b, min, max, fmfWithNaNInf);
|
|
Value ratioSq = arith::MulFOp::create(b, ratio, ratio, fmfWithNaNInf);
|
|
Value ratioSqPlusOne = arith::AddFOp::create(b, ratioSq, one, fmfWithNaNInf);
|
|
Value result;
|
|
|
|
if (fn == AbsFn::rsqrt) {
|
|
ratioSqPlusOne = math::RsqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
|
|
min = math::RsqrtOp::create(b, min, fmfWithNaNInf);
|
|
max = math::RsqrtOp::create(b, max, fmfWithNaNInf);
|
|
}
|
|
|
|
if (fn == AbsFn::sqrt) {
|
|
Value quarter = arith::ConstantOp::create(
|
|
b, real.getType(), b.getFloatAttr(real.getType(), 0.25));
|
|
// sqrt(sqrt(a*b)) would avoid the pow, but will overflow more easily.
|
|
Value sqrt = math::SqrtOp::create(b, max, fmfWithNaNInf);
|
|
Value p025 =
|
|
math::PowFOp::create(b, ratioSqPlusOne, quarter, fmfWithNaNInf);
|
|
result = arith::MulFOp::create(b, sqrt, p025, fmfWithNaNInf);
|
|
} else {
|
|
Value sqrt = math::SqrtOp::create(b, ratioSqPlusOne, fmfWithNaNInf);
|
|
result = arith::MulFOp::create(b, max, sqrt, fmfWithNaNInf);
|
|
}
|
|
|
|
Value isNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, result,
|
|
result, fmfWithNaNInf);
|
|
return arith::SelectOp::create(b, isNaN, min, result);
|
|
}
|
|
|
|
struct AbsOpConversion : public OpConversionPattern<complex::AbsOp> {
|
|
using OpConversionPattern<complex::AbsOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
|
|
|
|
Value real = complex::ReOp::create(b, adaptor.getComplex());
|
|
Value imag = complex::ImOp::create(b, adaptor.getComplex());
|
|
rewriter.replaceOp(op, computeAbs(real, imag, fmf, b));
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// atan2(y,x) = -i * log((x + i * y)/sqrt(x**2+y**2))
|
|
struct Atan2OpConversion : public OpConversionPattern<complex::Atan2Op> {
|
|
using OpConversionPattern<complex::Atan2Op>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::Atan2Op op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
auto type = cast<ComplexType>(op.getType());
|
|
Type elementType = type.getElementType();
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
|
|
Value lhs = adaptor.getLhs();
|
|
Value rhs = adaptor.getRhs();
|
|
|
|
Value rhsSquared = complex::MulOp::create(b, type, rhs, rhs, fmf);
|
|
Value lhsSquared = complex::MulOp::create(b, type, lhs, lhs, fmf);
|
|
Value rhsSquaredPlusLhsSquared =
|
|
complex::AddOp::create(b, type, rhsSquared, lhsSquared, fmf);
|
|
Value sqrtOfRhsSquaredPlusLhsSquared =
|
|
complex::SqrtOp::create(b, type, rhsSquaredPlusLhsSquared, fmf);
|
|
|
|
Value zero =
|
|
arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
|
|
Value one = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, 1));
|
|
Value i = complex::CreateOp::create(b, type, zero, one);
|
|
Value iTimesLhs = complex::MulOp::create(b, i, lhs, fmf);
|
|
Value rhsPlusILhs = complex::AddOp::create(b, rhs, iTimesLhs, fmf);
|
|
|
|
Value divResult = complex::DivOp::create(
|
|
b, rhsPlusILhs, sqrtOfRhsSquaredPlusLhsSquared, fmf);
|
|
Value logResult = complex::LogOp::create(b, divResult, fmf);
|
|
|
|
Value negativeOne = arith::ConstantOp::create(
|
|
b, elementType, b.getFloatAttr(elementType, -1));
|
|
Value negativeI = complex::CreateOp::create(b, type, zero, negativeOne);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::MulOp>(op, negativeI, logResult, fmf);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <typename ComparisonOp, arith::CmpFPredicate p>
|
|
struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
|
|
using OpConversionPattern<ComparisonOp>::OpConversionPattern;
|
|
using ResultCombiner =
|
|
std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
|
|
arith::AndIOp, arith::OrIOp>;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = cast<ComplexType>(adaptor.getLhs().getType()).getElementType();
|
|
|
|
Value realLhs =
|
|
complex::ReOp::create(rewriter, loc, type, adaptor.getLhs());
|
|
Value imagLhs =
|
|
complex::ImOp::create(rewriter, loc, type, adaptor.getLhs());
|
|
Value realRhs =
|
|
complex::ReOp::create(rewriter, loc, type, adaptor.getRhs());
|
|
Value imagRhs =
|
|
complex::ImOp::create(rewriter, loc, type, adaptor.getRhs());
|
|
Value realComparison =
|
|
arith::CmpFOp::create(rewriter, loc, p, realLhs, realRhs);
|
|
Value imagComparison =
|
|
arith::CmpFOp::create(rewriter, loc, p, imagLhs, imagRhs);
|
|
|
|
rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
|
|
imagComparison);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// Default conversion which applies the BinaryStandardOp separately on the real
|
|
// and imaginary parts. Can for example be used for complex::AddOp and
|
|
// complex::SubOp.
|
|
template <typename BinaryComplexOp, typename BinaryStandardOp>
|
|
struct BinaryComplexOpConversion : public OpConversionPattern<BinaryComplexOp> {
|
|
using OpConversionPattern<BinaryComplexOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = cast<ComplexType>(adaptor.getLhs().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
|
|
Value realLhs = complex::ReOp::create(b, elementType, adaptor.getLhs());
|
|
Value realRhs = complex::ReOp::create(b, elementType, adaptor.getRhs());
|
|
Value resultReal = BinaryStandardOp::create(b, elementType, realLhs,
|
|
realRhs, fmf.getValue());
|
|
Value imagLhs = complex::ImOp::create(b, elementType, adaptor.getLhs());
|
|
Value imagRhs = complex::ImOp::create(b, elementType, adaptor.getRhs());
|
|
Value resultImag = BinaryStandardOp::create(b, elementType, imagLhs,
|
|
imagRhs, fmf.getValue());
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <typename TrigonometricOp>
|
|
struct TrigonometricOpConversion : public OpConversionPattern<TrigonometricOp> {
|
|
using OpAdaptor = typename OpConversionPattern<TrigonometricOp>::OpAdaptor;
|
|
|
|
using OpConversionPattern<TrigonometricOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(TrigonometricOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
|
|
Value real =
|
|
complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
Value imag =
|
|
complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
|
|
// Trigonometric ops use a set of common building blocks to convert to real
|
|
// ops. Here we create these building blocks and call into an op-specific
|
|
// implementation in the subclass to combine them.
|
|
Value half = arith::ConstantOp::create(
|
|
rewriter, loc, elementType, rewriter.getFloatAttr(elementType, 0.5));
|
|
Value exp = math::ExpOp::create(rewriter, loc, imag, fmf);
|
|
Value scaledExp = arith::MulFOp::create(rewriter, loc, half, exp, fmf);
|
|
Value reciprocalExp = arith::DivFOp::create(rewriter, loc, half, exp, fmf);
|
|
Value sin = math::SinOp::create(rewriter, loc, real, fmf);
|
|
Value cos = math::CosOp::create(rewriter, loc, real, fmf);
|
|
|
|
auto resultPair =
|
|
combine(loc, scaledExp, reciprocalExp, sin, cos, rewriter, fmf);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultPair.first,
|
|
resultPair.second);
|
|
return success();
|
|
}
|
|
|
|
virtual std::pair<Value, Value>
|
|
combine(Location loc, Value scaledExp, Value reciprocalExp, Value sin,
|
|
Value cos, ConversionPatternRewriter &rewriter,
|
|
arith::FastMathFlagsAttr fmf) const = 0;
|
|
};
|
|
|
|
struct CosOpConversion : public TrigonometricOpConversion<complex::CosOp> {
|
|
using TrigonometricOpConversion<complex::CosOp>::TrigonometricOpConversion;
|
|
|
|
std::pair<Value, Value> combine(Location loc, Value scaledExp,
|
|
Value reciprocalExp, Value sin, Value cos,
|
|
ConversionPatternRewriter &rewriter,
|
|
arith::FastMathFlagsAttr fmf) const override {
|
|
// Complex cosine is defined as;
|
|
// cos(x + iy) = 0.5 * (exp(i(x + iy)) + exp(-i(x + iy)))
|
|
// Plugging in:
|
|
// exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
|
|
// exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
|
|
// and defining t := exp(y)
|
|
// We get:
|
|
// Re(cos(x + iy)) = (0.5/t + 0.5*t) * cos x
|
|
// Im(cos(x + iy)) = (0.5/t - 0.5*t) * sin x
|
|
Value sum =
|
|
arith::AddFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
|
|
Value resultReal = arith::MulFOp::create(rewriter, loc, sum, cos, fmf);
|
|
Value diff =
|
|
arith::SubFOp::create(rewriter, loc, reciprocalExp, scaledExp, fmf);
|
|
Value resultImag = arith::MulFOp::create(rewriter, loc, diff, sin, fmf);
|
|
return {resultReal, resultImag};
|
|
}
|
|
};
|
|
|
|
struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
|
|
DivOpConversion(MLIRContext *context, complex::ComplexRangeFlags target)
|
|
: OpConversionPattern<complex::DivOp>(context), complexRange(target) {}
|
|
|
|
using OpConversionPattern<complex::DivOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::DivOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = cast<ComplexType>(adaptor.getLhs().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
|
|
Value lhsReal =
|
|
complex::ReOp::create(rewriter, loc, elementType, adaptor.getLhs());
|
|
Value lhsImag =
|
|
complex::ImOp::create(rewriter, loc, elementType, adaptor.getLhs());
|
|
Value rhsReal =
|
|
complex::ReOp::create(rewriter, loc, elementType, adaptor.getRhs());
|
|
Value rhsImag =
|
|
complex::ImOp::create(rewriter, loc, elementType, adaptor.getRhs());
|
|
|
|
Value resultReal, resultImag;
|
|
|
|
if (complexRange == complex::ComplexRangeFlags::basic ||
|
|
complexRange == complex::ComplexRangeFlags::none) {
|
|
mlir::complex::convertDivToStandardUsingAlgebraic(
|
|
rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
|
|
&resultImag);
|
|
} else if (complexRange == complex::ComplexRangeFlags::improved) {
|
|
mlir::complex::convertDivToStandardUsingRangeReduction(
|
|
rewriter, loc, lhsReal, lhsImag, rhsReal, rhsImag, fmf, &resultReal,
|
|
&resultImag);
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
complex::ComplexRangeFlags complexRange;
|
|
};
|
|
|
|
struct ExpOpConversion : public OpConversionPattern<complex::ExpOp> {
|
|
using OpConversionPattern<complex::ExpOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
|
|
Value real =
|
|
complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
Value imag =
|
|
complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
Value expReal = math::ExpOp::create(rewriter, loc, real, fmf.getValue());
|
|
Value cosImag = math::CosOp::create(rewriter, loc, imag, fmf.getValue());
|
|
Value resultReal =
|
|
arith::MulFOp::create(rewriter, loc, expReal, cosImag, fmf.getValue());
|
|
Value sinImag = math::SinOp::create(rewriter, loc, imag, fmf.getValue());
|
|
Value resultImag =
|
|
arith::MulFOp::create(rewriter, loc, expReal, sinImag, fmf.getValue());
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg,
|
|
ArrayRef<double> coefficients,
|
|
arith::FastMathFlagsAttr fmf) {
|
|
auto argType = mlir::cast<FloatType>(arg.getType());
|
|
Value poly =
|
|
arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[0]));
|
|
for (unsigned i = 1; i < coefficients.size(); ++i) {
|
|
poly = math::FmaOp::create(
|
|
b, poly, arg,
|
|
arith::ConstantOp::create(b, b.getFloatAttr(argType, coefficients[i])),
|
|
fmf);
|
|
}
|
|
return poly;
|
|
}
|
|
|
|
struct Expm1OpConversion : public OpConversionPattern<complex::Expm1Op> {
|
|
using OpConversionPattern<complex::Expm1Op>::OpConversionPattern;
|
|
|
|
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
|
|
// [handle inaccuracies when a and/or b are small]
|
|
// = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i
|
|
// = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i
|
|
LogicalResult
|
|
matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = op.getType();
|
|
auto elemType = mlir::cast<FloatType>(type.getElementType());
|
|
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
Value real = complex::ReOp::create(b, adaptor.getComplex());
|
|
Value imag = complex::ImOp::create(b, adaptor.getComplex());
|
|
|
|
Value zero = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 0.0));
|
|
Value one = arith::ConstantOp::create(b, b.getFloatAttr(elemType, 1.0));
|
|
|
|
Value expm1Real = math::ExpM1Op::create(b, real, fmf);
|
|
Value expReal = arith::AddFOp::create(b, expm1Real, one, fmf);
|
|
|
|
Value sinImag = math::SinOp::create(b, imag, fmf);
|
|
Value cosm1Imag = emitCosm1(imag, fmf, b);
|
|
Value cosImag = arith::AddFOp::create(b, cosm1Imag, one, fmf);
|
|
|
|
Value realResult = arith::AddFOp::create(
|
|
b, arith::MulFOp::create(b, expm1Real, cosImag, fmf), cosm1Imag, fmf);
|
|
|
|
Value imagIsZero = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag,
|
|
zero, fmf.getValue());
|
|
Value imagResult = arith::SelectOp::create(
|
|
b, imagIsZero, zero, arith::MulFOp::create(b, expReal, sinImag, fmf));
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, realResult,
|
|
imagResult);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf,
|
|
ImplicitLocOpBuilder &b) const {
|
|
auto argType = mlir::cast<FloatType>(arg.getType());
|
|
auto negHalf = arith::ConstantOp::create(b, b.getFloatAttr(argType, -0.5));
|
|
auto negOne = arith::ConstantOp::create(b, b.getFloatAttr(argType, -1.0));
|
|
|
|
// Algorithm copied from cephes cosm1.
|
|
SmallVector<double, 7> kCoeffs{
|
|
4.7377507964246204691685E-14, -1.1470284843425359765671E-11,
|
|
2.0876754287081521758361E-9, -2.7557319214999787979814E-7,
|
|
2.4801587301570552304991E-5, -1.3888888888888872993737E-3,
|
|
4.1666666666666666609054E-2,
|
|
};
|
|
Value cos = math::CosOp::create(b, arg, fmf);
|
|
Value forLargeArg = arith::AddFOp::create(b, cos, negOne, fmf);
|
|
|
|
Value argPow2 = arith::MulFOp::create(b, arg, arg, fmf);
|
|
Value argPow4 = arith::MulFOp::create(b, argPow2, argPow2, fmf);
|
|
Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf);
|
|
|
|
auto forSmallArg =
|
|
arith::AddFOp::create(b, arith::MulFOp::create(b, argPow4, poly, fmf),
|
|
arith::MulFOp::create(b, negHalf, argPow2, fmf));
|
|
|
|
// (pi/4)^2 is approximately 0.61685
|
|
Value piOver4Pow2 =
|
|
arith::ConstantOp::create(b, b.getFloatAttr(argType, 0.61685));
|
|
Value cond = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, argPow2,
|
|
piOver4Pow2, fmf.getValue());
|
|
return arith::SelectOp::create(b, cond, forLargeArg, forSmallArg);
|
|
}
|
|
};
|
|
|
|
struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
|
|
using OpConversionPattern<complex::LogOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::LogOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
Value abs = complex::AbsOp::create(b, elementType, adaptor.getComplex(),
|
|
fmf.getValue());
|
|
Value resultReal = math::LogOp::create(b, elementType, abs, fmf.getValue());
|
|
Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
|
|
Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
|
|
Value resultImag =
|
|
math::Atan2Op::create(b, elementType, imag, real, fmf.getValue());
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct Log1pOpConversion : public OpConversionPattern<complex::Log1pOp> {
|
|
using OpConversionPattern<complex::Log1pOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
Value real = complex::ReOp::create(b, adaptor.getComplex());
|
|
Value imag = complex::ImOp::create(b, adaptor.getComplex());
|
|
|
|
Value half = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, 0.5));
|
|
Value one = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, 1));
|
|
Value realPlusOne = arith::AddFOp::create(b, real, one, fmf);
|
|
Value absRealPlusOne = math::AbsFOp::create(b, realPlusOne, fmf);
|
|
Value absImag = math::AbsFOp::create(b, imag, fmf);
|
|
|
|
Value maxAbs = arith::MaximumFOp::create(b, absRealPlusOne, absImag, fmf);
|
|
Value minAbs = arith::MinimumFOp::create(b, absRealPlusOne, absImag, fmf);
|
|
|
|
Value useReal = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT,
|
|
realPlusOne, absImag, fmf);
|
|
Value maxMinusOne = arith::SubFOp::create(b, maxAbs, one, fmf);
|
|
Value maxAbsOfRealPlusOneAndImagMinusOne =
|
|
arith::SelectOp::create(b, useReal, real, maxMinusOne);
|
|
arith::FastMathFlags fmfWithNaNInf = arith::bitEnumClear(
|
|
fmf, arith::FastMathFlags::nnan | arith::FastMathFlags::ninf);
|
|
Value minMaxRatio = arith::DivFOp::create(b, minAbs, maxAbs, fmfWithNaNInf);
|
|
Value logOfMaxAbsOfRealPlusOneAndImag =
|
|
math::Log1pOp::create(b, maxAbsOfRealPlusOneAndImagMinusOne, fmf);
|
|
Value logOfSqrtPart = math::Log1pOp::create(
|
|
b, arith::MulFOp::create(b, minMaxRatio, minMaxRatio, fmfWithNaNInf),
|
|
fmfWithNaNInf);
|
|
Value r = arith::AddFOp::create(
|
|
b, arith::MulFOp::create(b, half, logOfSqrtPart, fmfWithNaNInf),
|
|
logOfMaxAbsOfRealPlusOneAndImag, fmfWithNaNInf);
|
|
Value resultReal = arith::SelectOp::create(
|
|
b,
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, r, r,
|
|
fmfWithNaNInf),
|
|
minAbs, r);
|
|
Value resultImag = math::Atan2Op::create(b, imag, realPlusOne, fmf);
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct MulOpConversion : public OpConversionPattern<complex::MulOp> {
|
|
using OpConversionPattern<complex::MulOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::MulOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
auto type = cast<ComplexType>(adaptor.getLhs().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
auto fmfValue = fmf.getValue();
|
|
Value lhsReal = complex::ReOp::create(b, elementType, adaptor.getLhs());
|
|
Value lhsImag = complex::ImOp::create(b, elementType, adaptor.getLhs());
|
|
Value rhsReal = complex::ReOp::create(b, elementType, adaptor.getRhs());
|
|
Value rhsImag = complex::ImOp::create(b, elementType, adaptor.getRhs());
|
|
Value lhsRealTimesRhsReal =
|
|
arith::MulFOp::create(b, lhsReal, rhsReal, fmfValue);
|
|
Value lhsImagTimesRhsImag =
|
|
arith::MulFOp::create(b, lhsImag, rhsImag, fmfValue);
|
|
Value real = arith::SubFOp::create(b, lhsRealTimesRhsReal,
|
|
lhsImagTimesRhsImag, fmfValue);
|
|
Value lhsImagTimesRhsReal =
|
|
arith::MulFOp::create(b, lhsImag, rhsReal, fmfValue);
|
|
Value lhsRealTimesRhsImag =
|
|
arith::MulFOp::create(b, lhsReal, rhsImag, fmfValue);
|
|
Value imag = arith::AddFOp::create(b, lhsImagTimesRhsReal,
|
|
lhsRealTimesRhsImag, fmfValue);
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, imag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
|
|
using OpConversionPattern<complex::NegOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::NegOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
|
|
Value real =
|
|
complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
Value imag =
|
|
complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
Value negReal = arith::NegFOp::create(rewriter, loc, real);
|
|
Value negImag = arith::NegFOp::create(rewriter, loc, imag);
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct SinOpConversion : public TrigonometricOpConversion<complex::SinOp> {
|
|
using TrigonometricOpConversion<complex::SinOp>::TrigonometricOpConversion;
|
|
|
|
std::pair<Value, Value> combine(Location loc, Value scaledExp,
|
|
Value reciprocalExp, Value sin, Value cos,
|
|
ConversionPatternRewriter &rewriter,
|
|
arith::FastMathFlagsAttr fmf) const override {
|
|
// Complex sine is defined as;
|
|
// sin(x + iy) = -0.5i * (exp(i(x + iy)) - exp(-i(x + iy)))
|
|
// Plugging in:
|
|
// exp(i(x+iy)) = exp(-y + ix) = exp(-y)(cos(x) + i sin(x))
|
|
// exp(-i(x+iy)) = exp(y + i(-x)) = exp(y)(cos(x) + i (-sin(x)))
|
|
// and defining t := exp(y)
|
|
// We get:
|
|
// Re(sin(x + iy)) = (0.5*t + 0.5/t) * sin x
|
|
// Im(cos(x + iy)) = (0.5*t - 0.5/t) * cos x
|
|
Value sum =
|
|
arith::AddFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
|
|
Value resultReal = arith::MulFOp::create(rewriter, loc, sum, sin, fmf);
|
|
Value diff =
|
|
arith::SubFOp::create(rewriter, loc, scaledExp, reciprocalExp, fmf);
|
|
Value resultImag = arith::MulFOp::create(rewriter, loc, diff, cos, fmf);
|
|
return {resultReal, resultImag};
|
|
}
|
|
};
|
|
|
|
// The algorithm is listed in https://dl.acm.org/doi/pdf/10.1145/363717.363780.
|
|
struct SqrtOpConversion : public OpConversionPattern<complex::SqrtOp> {
|
|
using OpConversionPattern<complex::SqrtOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::SqrtOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
|
|
auto type = cast<ComplexType>(op.getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
|
|
|
|
auto cst = [&](APFloat v) {
|
|
return arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, v));
|
|
};
|
|
const auto &floatSemantics = elementType.getFloatSemantics();
|
|
Value zero = cst(APFloat::getZero(floatSemantics));
|
|
Value half = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, 0.5));
|
|
|
|
Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
|
|
Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
|
|
Value absSqrt = computeAbs(real, imag, fmf, b, AbsFn::sqrt);
|
|
Value argArg = math::Atan2Op::create(b, imag, real, fmf);
|
|
Value sqrtArg = arith::MulFOp::create(b, argArg, half, fmf);
|
|
Value cos = math::CosOp::create(b, sqrtArg, fmf);
|
|
Value sin = math::SinOp::create(b, sqrtArg, fmf);
|
|
// sin(atan2(0, inf)) = 0, sqrt(abs(inf)) = inf, but we can't multiply
|
|
// 0 * inf.
|
|
Value sinIsZero =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, sin, zero, fmf);
|
|
|
|
Value resultReal = arith::MulFOp::create(b, absSqrt, cos, fmf);
|
|
Value resultImag = arith::SelectOp::create(
|
|
b, sinIsZero, zero, arith::MulFOp::create(b, absSqrt, sin, fmf));
|
|
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
|
|
arith::FastMathFlags::ninf)) {
|
|
Value inf = cst(APFloat::getInf(floatSemantics));
|
|
Value negInf = cst(APFloat::getInf(floatSemantics, true));
|
|
Value nan = cst(APFloat::getNaN(floatSemantics));
|
|
Value absImag = math::AbsFOp::create(b, elementType, imag, fmf);
|
|
|
|
Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
|
|
absImag, inf, fmf);
|
|
Value absImagIsNotInf = arith::CmpFOp::create(
|
|
b, arith::CmpFPredicate::ONE, absImag, inf, fmf);
|
|
Value realIsInf =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, inf, fmf);
|
|
Value realIsNegInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
|
|
real, negInf, fmf);
|
|
|
|
resultReal = arith::SelectOp::create(
|
|
b, arith::AndIOp::create(b, realIsNegInf, absImagIsNotInf), zero,
|
|
resultReal);
|
|
resultReal = arith::SelectOp::create(
|
|
b, arith::OrIOp::create(b, absImagIsInf, realIsInf), inf, resultReal);
|
|
|
|
Value imagSignInf = math::CopySignOp::create(b, inf, imag, fmf);
|
|
resultImag = arith::SelectOp::create(
|
|
b,
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, absSqrt, absSqrt),
|
|
nan, resultImag);
|
|
resultImag = arith::SelectOp::create(
|
|
b, arith::OrIOp::create(b, absImagIsInf, realIsNegInf), imagSignInf,
|
|
resultImag);
|
|
}
|
|
|
|
Value resultIsZero =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, absSqrt, zero, fmf);
|
|
resultReal = arith::SelectOp::create(b, resultIsZero, zero, resultReal);
|
|
resultImag = arith::SelectOp::create(b, resultIsZero, zero, resultImag);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct SignOpConversion : public OpConversionPattern<complex::SignOp> {
|
|
using OpConversionPattern<complex::SignOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::SignOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
|
|
Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
|
|
Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
|
|
Value zero =
|
|
arith::ConstantOp::create(b, elementType, b.getZeroAttr(elementType));
|
|
Value realIsZero =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero);
|
|
Value imagIsZero =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero);
|
|
Value isZero = arith::AndIOp::create(b, realIsZero, imagIsZero);
|
|
auto abs =
|
|
complex::AbsOp::create(b, elementType, adaptor.getComplex(), fmf);
|
|
Value realSign = arith::DivFOp::create(b, real, abs, fmf);
|
|
Value imagSign = arith::DivFOp::create(b, imag, abs, fmf);
|
|
Value sign = complex::CreateOp::create(b, type, realSign, imagSign);
|
|
rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isZero,
|
|
adaptor.getComplex(), sign);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <typename Op>
|
|
struct TanTanhOpConversion : public OpConversionPattern<Op> {
|
|
using OpConversionPattern<Op>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
auto loc = op.getLoc();
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
|
|
const auto &floatSemantics = elementType.getFloatSemantics();
|
|
|
|
Value real =
|
|
complex::ReOp::create(b, loc, elementType, adaptor.getComplex());
|
|
Value imag =
|
|
complex::ImOp::create(b, loc, elementType, adaptor.getComplex());
|
|
Value negOne = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, -1.0));
|
|
|
|
if constexpr (std::is_same_v<Op, complex::TanOp>) {
|
|
// tan(x+yi) = -i*tanh(-y + xi)
|
|
std::swap(real, imag);
|
|
real = arith::MulFOp::create(b, real, negOne, fmf);
|
|
}
|
|
|
|
auto cst = [&](APFloat v) {
|
|
return arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, v));
|
|
};
|
|
Value inf = cst(APFloat::getInf(floatSemantics));
|
|
Value four = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, 4.0));
|
|
Value twoReal = arith::AddFOp::create(b, real, real, fmf);
|
|
Value negTwoReal = arith::MulFOp::create(b, negOne, twoReal, fmf);
|
|
|
|
Value expTwoRealMinusOne = math::ExpM1Op::create(b, twoReal, fmf);
|
|
Value expNegTwoRealMinusOne = math::ExpM1Op::create(b, negTwoReal, fmf);
|
|
Value realNum = arith::SubFOp::create(b, expTwoRealMinusOne,
|
|
expNegTwoRealMinusOne, fmf);
|
|
|
|
Value cosImag = math::CosOp::create(b, imag, fmf);
|
|
Value cosImagSq = arith::MulFOp::create(b, cosImag, cosImag, fmf);
|
|
Value twoCosTwoImagPlusOne = arith::MulFOp::create(b, cosImagSq, four, fmf);
|
|
Value sinImag = math::SinOp::create(b, imag, fmf);
|
|
|
|
Value imagNum = arith::MulFOp::create(
|
|
b, four, arith::MulFOp::create(b, cosImag, sinImag, fmf), fmf);
|
|
|
|
Value expSumMinusTwo = arith::AddFOp::create(b, expTwoRealMinusOne,
|
|
expNegTwoRealMinusOne, fmf);
|
|
Value denom =
|
|
arith::AddFOp::create(b, expSumMinusTwo, twoCosTwoImagPlusOne, fmf);
|
|
|
|
Value isInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
|
|
expSumMinusTwo, inf, fmf);
|
|
Value realLimit = math::CopySignOp::create(b, negOne, real, fmf);
|
|
|
|
Value resultReal = arith::SelectOp::create(
|
|
b, isInf, realLimit, arith::DivFOp::create(b, realNum, denom, fmf));
|
|
Value resultImag = arith::DivFOp::create(b, imagNum, denom, fmf);
|
|
|
|
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
|
|
arith::FastMathFlags::ninf)) {
|
|
Value absReal = math::AbsFOp::create(b, real, fmf);
|
|
Value zero = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, 0.0));
|
|
Value nan = cst(APFloat::getNaN(floatSemantics));
|
|
|
|
Value absRealIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
|
|
absReal, inf, fmf);
|
|
Value imagIsZero =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
|
|
Value absRealIsNotInf = arith::XOrIOp::create(
|
|
b, absRealIsInf, arith::ConstantIntOp::create(b, true, /*width=*/1));
|
|
|
|
Value imagNumIsNaN = arith::CmpFOp::create(b, arith::CmpFPredicate::UNO,
|
|
imagNum, imagNum, fmf);
|
|
Value resultRealIsNaN =
|
|
arith::AndIOp::create(b, imagNumIsNaN, absRealIsNotInf);
|
|
Value resultImagIsZero = arith::OrIOp::create(
|
|
b, imagIsZero, arith::AndIOp::create(b, absRealIsInf, imagNumIsNaN));
|
|
|
|
resultReal = arith::SelectOp::create(b, resultRealIsNaN, nan, resultReal);
|
|
resultImag =
|
|
arith::SelectOp::create(b, resultImagIsZero, zero, resultImag);
|
|
}
|
|
|
|
if constexpr (std::is_same_v<Op, complex::TanOp>) {
|
|
// tan(x+yi) = -i*tanh(-y + xi)
|
|
std::swap(resultReal, resultImag);
|
|
resultImag = arith::MulFOp::create(b, resultImag, negOne, fmf);
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
|
|
using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
Value real =
|
|
complex::ReOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
Value imag =
|
|
complex::ImOp::create(rewriter, loc, elementType, adaptor.getComplex());
|
|
Value negImag = arith::NegFOp::create(rewriter, loc, elementType, imag);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts lhs^y = (a+bi)^(c+di) to
|
|
/// (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
|
|
/// where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
|
|
static Value powOpConversionImpl(mlir::ImplicitLocOpBuilder &builder,
|
|
ComplexType type, Value lhs, Value c, Value d,
|
|
arith::FastMathFlags fmf) {
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
|
|
Value a = complex::ReOp::create(builder, lhs);
|
|
Value b = complex::ImOp::create(builder, lhs);
|
|
|
|
Value abs = complex::AbsOp::create(builder, lhs, fmf);
|
|
Value absToC = math::PowFOp::create(builder, abs, c, fmf);
|
|
|
|
Value negD = arith::NegFOp::create(builder, d, fmf);
|
|
Value argLhs = math::Atan2Op::create(builder, b, a, fmf);
|
|
Value negDArgLhs = arith::MulFOp::create(builder, negD, argLhs, fmf);
|
|
Value expNegDArgLhs = math::ExpOp::create(builder, negDArgLhs, fmf);
|
|
|
|
Value coeff = arith::MulFOp::create(builder, absToC, expNegDArgLhs, fmf);
|
|
Value lnAbs = math::LogOp::create(builder, abs, fmf);
|
|
Value cArgLhs = arith::MulFOp::create(builder, c, argLhs, fmf);
|
|
Value dLnAbs = arith::MulFOp::create(builder, d, lnAbs, fmf);
|
|
Value q = arith::AddFOp::create(builder, cArgLhs, dLnAbs, fmf);
|
|
Value cosQ = math::CosOp::create(builder, q, fmf);
|
|
Value sinQ = math::SinOp::create(builder, q, fmf);
|
|
|
|
Value inf = arith::ConstantOp::create(
|
|
builder, elementType,
|
|
builder.getFloatAttr(elementType,
|
|
APFloat::getInf(elementType.getFloatSemantics())));
|
|
Value zero = arith::ConstantOp::create(
|
|
builder, elementType, builder.getFloatAttr(elementType, 0.0));
|
|
Value one = arith::ConstantOp::create(builder, elementType,
|
|
builder.getFloatAttr(elementType, 1.0));
|
|
Value complexOne = complex::CreateOp::create(builder, type, one, zero);
|
|
Value complexZero = complex::CreateOp::create(builder, type, zero, zero);
|
|
Value complexInf = complex::CreateOp::create(builder, type, inf, zero);
|
|
|
|
// Case 0:
|
|
// d^c is 0 if d is 0 and c > 0. 0^0 is defined to be 1.0, see
|
|
// Branch Cuts for Complex Elementary Functions or Much Ado About
|
|
// Nothing's Sign Bit, W. Kahan, Section 10.
|
|
Value absEqZero =
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, abs, zero, fmf);
|
|
Value dEqZero =
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, d, zero, fmf);
|
|
Value cEqZero =
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, c, zero, fmf);
|
|
Value bEqZero =
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, b, zero, fmf);
|
|
|
|
Value zeroLeC =
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OLE, zero, c, fmf);
|
|
Value coeffCosQ = arith::MulFOp::create(builder, coeff, cosQ, fmf);
|
|
Value coeffSinQ = arith::MulFOp::create(builder, coeff, sinQ, fmf);
|
|
Value complexOneOrZero =
|
|
arith::SelectOp::create(builder, cEqZero, complexOne, complexZero);
|
|
Value coeffCosSin =
|
|
complex::CreateOp::create(builder, type, coeffCosQ, coeffSinQ);
|
|
Value cutoff0 = arith::SelectOp::create(
|
|
builder,
|
|
arith::AndIOp::create(
|
|
builder, arith::AndIOp::create(builder, absEqZero, dEqZero), zeroLeC),
|
|
complexOneOrZero, coeffCosSin);
|
|
|
|
// Case 1:
|
|
// x^0 is defined to be 1 for any x, see
|
|
// Branch Cuts for Complex Elementary Functions or Much Ado About
|
|
// Nothing's Sign Bit, W. Kahan, Section 10.
|
|
Value rhsEqZero = arith::AndIOp::create(builder, cEqZero, dEqZero);
|
|
Value cutoff1 =
|
|
arith::SelectOp::create(builder, rhsEqZero, complexOne, cutoff0);
|
|
|
|
// Case 2:
|
|
// 1^(c + d*i) = 1 + 0*i
|
|
Value lhsEqOne = arith::AndIOp::create(
|
|
builder,
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, one, fmf),
|
|
bEqZero);
|
|
Value cutoff2 =
|
|
arith::SelectOp::create(builder, lhsEqOne, complexOne, cutoff1);
|
|
|
|
// Case 3:
|
|
// inf^(c + 0*i) = inf + 0*i, c > 0
|
|
Value lhsEqInf = arith::AndIOp::create(
|
|
builder,
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OEQ, a, inf, fmf),
|
|
bEqZero);
|
|
Value rhsGt0 = arith::AndIOp::create(
|
|
builder, dEqZero,
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OGT, c, zero, fmf));
|
|
Value cutoff3 = arith::SelectOp::create(
|
|
builder, arith::AndIOp::create(builder, lhsEqInf, rhsGt0), complexInf,
|
|
cutoff2);
|
|
|
|
// Case 4:
|
|
// inf^(c + 0*i) = 0 + 0*i, c < 0
|
|
Value rhsLt0 = arith::AndIOp::create(
|
|
builder, dEqZero,
|
|
arith::CmpFOp::create(builder, arith::CmpFPredicate::OLT, c, zero, fmf));
|
|
Value cutoff4 = arith::SelectOp::create(
|
|
builder, arith::AndIOp::create(builder, lhsEqInf, rhsLt0), complexZero,
|
|
cutoff3);
|
|
|
|
return cutoff4;
|
|
}
|
|
|
|
struct PowOpConversion : public OpConversionPattern<complex::PowOp> {
|
|
using OpConversionPattern<complex::PowOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::PowOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
mlir::ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
|
|
auto type = cast<ComplexType>(adaptor.getLhs().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
|
|
Value c = complex::ReOp::create(builder, elementType, adaptor.getRhs());
|
|
Value d = complex::ImOp::create(builder, elementType, adaptor.getRhs());
|
|
|
|
rewriter.replaceOp(op, {powOpConversionImpl(builder, type, adaptor.getLhs(),
|
|
c, d, op.getFastmath())});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct RsqrtOpConversion : public OpConversionPattern<complex::RsqrtOp> {
|
|
using OpConversionPattern<complex::RsqrtOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::RsqrtOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
|
|
auto type = cast<ComplexType>(adaptor.getComplex().getType());
|
|
auto elementType = cast<FloatType>(type.getElementType());
|
|
|
|
arith::FastMathFlags fmf = op.getFastMathFlagsAttr().getValue();
|
|
|
|
auto cst = [&](APFloat v) {
|
|
return arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, v));
|
|
};
|
|
const auto &floatSemantics = elementType.getFloatSemantics();
|
|
Value zero = cst(APFloat::getZero(floatSemantics));
|
|
Value inf = cst(APFloat::getInf(floatSemantics));
|
|
Value negHalf = arith::ConstantOp::create(
|
|
b, elementType, b.getFloatAttr(elementType, -0.5));
|
|
Value nan = cst(APFloat::getNaN(floatSemantics));
|
|
|
|
Value real = complex::ReOp::create(b, elementType, adaptor.getComplex());
|
|
Value imag = complex::ImOp::create(b, elementType, adaptor.getComplex());
|
|
Value absRsqrt = computeAbs(real, imag, fmf, b, AbsFn::rsqrt);
|
|
Value argArg = math::Atan2Op::create(b, imag, real, fmf);
|
|
Value rsqrtArg = arith::MulFOp::create(b, argArg, negHalf, fmf);
|
|
Value cos = math::CosOp::create(b, rsqrtArg, fmf);
|
|
Value sin = math::SinOp::create(b, rsqrtArg, fmf);
|
|
|
|
Value resultReal = arith::MulFOp::create(b, absRsqrt, cos, fmf);
|
|
Value resultImag = arith::MulFOp::create(b, absRsqrt, sin, fmf);
|
|
|
|
if (!arith::bitEnumContainsAll(fmf, arith::FastMathFlags::nnan |
|
|
arith::FastMathFlags::ninf)) {
|
|
Value negOne = arith::ConstantOp::create(b, elementType,
|
|
b.getFloatAttr(elementType, -1));
|
|
|
|
Value realSignedZero = math::CopySignOp::create(b, zero, real, fmf);
|
|
Value imagSignedZero = math::CopySignOp::create(b, zero, imag, fmf);
|
|
Value negImagSignedZero =
|
|
arith::MulFOp::create(b, negOne, imagSignedZero, fmf);
|
|
|
|
Value absReal = math::AbsFOp::create(b, real, fmf);
|
|
Value absImag = math::AbsFOp::create(b, imag, fmf);
|
|
|
|
Value absImagIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
|
|
absImag, inf, fmf);
|
|
Value realIsNan =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::UNO, real, real, fmf);
|
|
Value realIsInf = arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ,
|
|
absReal, inf, fmf);
|
|
Value inIsNanInf = arith::AndIOp::create(b, absImagIsInf, realIsNan);
|
|
|
|
Value resultIsZero = arith::OrIOp::create(b, inIsNanInf, realIsInf);
|
|
|
|
resultReal =
|
|
arith::SelectOp::create(b, resultIsZero, realSignedZero, resultReal);
|
|
resultImag = arith::SelectOp::create(b, resultIsZero, negImagSignedZero,
|
|
resultImag);
|
|
}
|
|
|
|
Value isRealZero =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, real, zero, fmf);
|
|
Value isImagZero =
|
|
arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, imag, zero, fmf);
|
|
Value isZero = arith::AndIOp::create(b, isRealZero, isImagZero);
|
|
|
|
resultReal = arith::SelectOp::create(b, isZero, inf, resultReal);
|
|
resultImag = arith::SelectOp::create(b, isZero, nan, resultImag);
|
|
|
|
rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
|
|
resultImag);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct AngleOpConversion : public OpConversionPattern<complex::AngleOp> {
|
|
using OpConversionPattern<complex::AngleOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(complex::AngleOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto loc = op.getLoc();
|
|
auto type = op.getType();
|
|
arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr();
|
|
|
|
Value real =
|
|
complex::ReOp::create(rewriter, loc, type, adaptor.getComplex());
|
|
Value imag =
|
|
complex::ImOp::create(rewriter, loc, type, adaptor.getComplex());
|
|
|
|
rewriter.replaceOpWithNewOp<math::Atan2Op>(op, imag, real, fmf);
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void mlir::populateComplexToStandardConversionPatterns(
|
|
RewritePatternSet &patterns, complex::ComplexRangeFlags complexRange) {
|
|
// clang-format off
|
|
patterns.add<
|
|
AbsOpConversion,
|
|
AngleOpConversion,
|
|
Atan2OpConversion,
|
|
BinaryComplexOpConversion<complex::AddOp, arith::AddFOp>,
|
|
BinaryComplexOpConversion<complex::SubOp, arith::SubFOp>,
|
|
ComparisonOpConversion<complex::EqualOp, arith::CmpFPredicate::OEQ>,
|
|
ComparisonOpConversion<complex::NotEqualOp, arith::CmpFPredicate::UNE>,
|
|
ConjOpConversion,
|
|
CosOpConversion,
|
|
ExpOpConversion,
|
|
Expm1OpConversion,
|
|
Log1pOpConversion,
|
|
LogOpConversion,
|
|
MulOpConversion,
|
|
NegOpConversion,
|
|
SignOpConversion,
|
|
SinOpConversion,
|
|
SqrtOpConversion,
|
|
TanTanhOpConversion<complex::TanOp>,
|
|
TanTanhOpConversion<complex::TanhOp>,
|
|
PowOpConversion,
|
|
RsqrtOpConversion
|
|
>(patterns.getContext());
|
|
|
|
patterns.add<DivOpConversion>(patterns.getContext(), complexRange);
|
|
|
|
// clang-format on
|
|
}
|
|
|
|
namespace {
|
|
struct ConvertComplexToStandardPass
|
|
: public impl::ConvertComplexToStandardPassBase<
|
|
ConvertComplexToStandardPass> {
|
|
using Base::Base;
|
|
|
|
void runOnOperation() override;
|
|
};
|
|
|
|
void ConvertComplexToStandardPass::runOnOperation() {
|
|
// Convert to the Standard dialect using the converter defined above.
|
|
RewritePatternSet patterns(&getContext());
|
|
populateComplexToStandardConversionPatterns(patterns, complexRange);
|
|
|
|
ConversionTarget target(getContext());
|
|
target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
|
|
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
|
|
if (failed(
|
|
applyPartialConversion(getOperation(), target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|
|
} // namespace
|