[MLIR][Math] Add erfc to math dialect (#126439)
This patch adds the erfc op to the math dialect. It also does lowering of the math.erfc op to libm calls. There is also a f32 polynomial approximation for the function based on https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf This is in turn based on M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, No. 153, January 1981, pp. 249-253. The code has a ULP error less than 3, which was tested, and MLIR test values were verified against the C implementation.
This commit is contained in:
parent
e1a393e392
commit
8806311bb7
@ -560,6 +560,31 @@ def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ErfcOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Math_ErfcOp : Math_FloatUnaryOp<"erfc"> {
|
||||
let summary = "complementary error function of the specified value";
|
||||
let description = [{
|
||||
|
||||
The `erfc` operation computes the complementary error function, defined as
|
||||
1-erf(x). This function is part of libm and is needed for accuracy, since
|
||||
simply calculating 1-erf(x) when x is close to 1 will give inaccurate results.
|
||||
It takes one operand of floating point type (i.e., scalar,
|
||||
tensor or vector) and returns one result of the same type. It has no
|
||||
standard attributes.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
// Scalar error function value.
|
||||
%a = math.erfc %b : f64
|
||||
```
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ExpOp
|
||||
|
||||
@ -23,6 +23,14 @@ public:
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
|
||||
struct ErfcPolynomialApproximation : public OpRewritePattern<math::ErfcOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(math::ErfcOp op,
|
||||
PatternRewriter &rewriter) const final;
|
||||
};
|
||||
|
||||
} // namespace math
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ struct MathPolynomialApproximationOptions {
|
||||
|
||||
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
|
||||
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
|
||||
void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
|
||||
|
||||
// Adds patterns to convert to f32 around math functions for which `predicate`
|
||||
// returns true.
|
||||
|
||||
@ -181,6 +181,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
|
||||
populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
|
||||
populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
|
||||
populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
|
||||
populatePatternsForOp<math::ErfcOp>(patterns, benefit, ctx, "erfcf", "erfc");
|
||||
populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
|
||||
populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
|
||||
populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",
|
||||
|
||||
@ -332,6 +332,24 @@ OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ErfcOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
|
||||
return constFoldUnaryOpConditional<FloatAttr>(
|
||||
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
|
||||
switch (APFloat::SemanticsToEnum(a.getSemantics())) {
|
||||
case APFloat::Semantics::S_IEEEdouble:
|
||||
return APFloat(erfc(a.convertToDouble()));
|
||||
case APFloat::Semantics::S_IEEEsingle:
|
||||
return APFloat(erfcf(a.convertToFloat()));
|
||||
default:
|
||||
return {};
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IPowIOp folder
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
|
||||
// Helper functions to create constants.
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
|
||||
return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
|
||||
}
|
||||
|
||||
static Value floatCst(ImplicitLocOpBuilder &builder, float value,
|
||||
Type elementType) {
|
||||
assert((elementType.isF16() || elementType.isF32()) &&
|
||||
@ -1118,6 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
|
||||
return success();
|
||||
}
|
||||
|
||||
// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a 9 degree
|
||||
// polynomial.This approximation is based on the following stackoverflow post:
|
||||
// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
|
||||
// The stackoverflow post is in turn based on:
|
||||
// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
|
||||
// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
|
||||
// No. 153, January 1981, pp. 249-253.
|
||||
//
|
||||
// Maximum error: 2.65 ulps
|
||||
LogicalResult
|
||||
ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
|
||||
PatternRewriter &rewriter) const {
|
||||
Value x = op.getOperand();
|
||||
Type et = getElementTypeOrSelf(x);
|
||||
|
||||
if (!et.isF32())
|
||||
return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
|
||||
std::optional<VectorShape> shape = vectorShape(x);
|
||||
|
||||
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
|
||||
auto bcast = [&](Value value) -> Value {
|
||||
return broadcast(builder, value, shape);
|
||||
};
|
||||
|
||||
Value trueValue = bcast(boolCst(builder, true));
|
||||
Value zero = bcast(floatCst(builder, 0.0f, et));
|
||||
Value one = bcast(floatCst(builder, 1.0f, et));
|
||||
Value onehalf = bcast(floatCst(builder, 0.5f, et));
|
||||
Value neg4 = bcast(floatCst(builder, -4.0f, et));
|
||||
Value neg2 = bcast(floatCst(builder, -2.0f, et));
|
||||
Value pos2 = bcast(floatCst(builder, 2.0f, et));
|
||||
Value posInf = bcast(floatCst(builder, INFINITY, et));
|
||||
Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
|
||||
|
||||
Value a = builder.create<math::AbsFOp>(x);
|
||||
Value p = builder.create<arith::AddFOp>(a, pos2);
|
||||
Value r = builder.create<arith::DivFOp>(one, p);
|
||||
Value q = builder.create<math::FmaOp>(neg4, r, one);
|
||||
Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
|
||||
neg2, a);
|
||||
Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
|
||||
q = builder.create<math::FmaOp>(r, e, q);
|
||||
|
||||
p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
|
||||
Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
|
||||
p = builder.create<math::FmaOp>(p, q, c1);
|
||||
Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
|
||||
p = builder.create<math::FmaOp>(p, q, c2);
|
||||
Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
|
||||
p = builder.create<math::FmaOp>(p, q, c3);
|
||||
Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
|
||||
p = builder.create<math::FmaOp>(p, q, c4);
|
||||
Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
|
||||
p = builder.create<math::FmaOp>(p, q, c5);
|
||||
Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
|
||||
p = builder.create<math::FmaOp>(p, q, c6);
|
||||
Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
|
||||
p = builder.create<math::FmaOp>(p, q, c7);
|
||||
Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
|
||||
p = builder.create<math::FmaOp>(p, q, c8);
|
||||
Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
|
||||
p = builder.create<math::FmaOp>(p, q, c9);
|
||||
|
||||
Value d = builder.create<math::FmaOp>(pos2, a, one);
|
||||
r = builder.create<arith::DivFOp>(one, d);
|
||||
q = builder.create<math::FmaOp>(p, r, r);
|
||||
Value negfa = builder.create<arith::NegFOp>(a);
|
||||
Value fmaqah = builder.create<math::FmaOp>(q, negfa, onehalf);
|
||||
Value psubq = builder.create<arith::SubFOp>(p, q);
|
||||
e = builder.create<math::FmaOp>(fmaqah, pos2, psubq);
|
||||
r = builder.create<math::FmaOp>(e, r, q);
|
||||
|
||||
Value s = builder.create<arith::MulFOp>(a, a);
|
||||
e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
|
||||
|
||||
t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
|
||||
r = builder.create<math::FmaOp>(
|
||||
r, e,
|
||||
builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
|
||||
|
||||
Value isNotLessThanInf = builder.create<arith::XOrIOp>(
|
||||
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
|
||||
trueValue);
|
||||
r = builder.create<arith::SelectOp>(isNotLessThanInf,
|
||||
builder.create<arith::AddFOp>(x, x), r);
|
||||
Value isGreaterThanClamp =
|
||||
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
|
||||
r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
|
||||
|
||||
Value isNegative =
|
||||
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
|
||||
r = builder.create<arith::SelectOp>(
|
||||
isNegative, builder.create<arith::SubFOp>(pos2, r), r);
|
||||
|
||||
rewriter.replaceOp(op, r);
|
||||
return success();
|
||||
}
|
||||
//----------------------------------------------------------------------------//
|
||||
// Exp approximation.
|
||||
//----------------------------------------------------------------------------//
|
||||
@ -1667,6 +1768,11 @@ void mlir::populatePolynomialApproximateErfPattern(
|
||||
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::populatePolynomialApproximateErfcPattern(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
static void
|
||||
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
|
||||
@ -1690,6 +1796,7 @@ void mlir::populateMathF32ExpansionPatterns(
|
||||
populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
|
||||
populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
|
||||
populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
|
||||
populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
|
||||
populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
|
||||
populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
|
||||
populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
|
||||
@ -1734,6 +1841,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
|
||||
CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
|
||||
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
|
||||
patterns, predicate);
|
||||
populateMathPolynomialApproximationPattern<ErfcOp,
|
||||
ErfcPolynomialApproximation>(
|
||||
patterns, predicate);
|
||||
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
|
||||
patterns, predicate);
|
||||
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
|
||||
@ -1760,9 +1870,10 @@ void mlir::populateMathPolynomialApproximationPatterns(
|
||||
{math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
|
||||
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
|
||||
math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
|
||||
math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
|
||||
math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
|
||||
math::SinOp::getOperationName(), math::CosOp::getOperationName()},
|
||||
math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
|
||||
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
|
||||
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
|
||||
math::CosOp::getOperationName()},
|
||||
name);
|
||||
});
|
||||
|
||||
@ -1774,8 +1885,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
|
||||
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
|
||||
math::Log2Op::getOperationName(),
|
||||
math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
|
||||
math::AsinOp::getOperationName(), math::AcosOp::getOperationName(),
|
||||
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
|
||||
math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
|
||||
math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
|
||||
math::ExpM1Op::getOperationName(),
|
||||
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
|
||||
math::CosOp::getOperationName()},
|
||||
name);
|
||||
|
||||
@ -81,6 +81,116 @@ func.func @erf_scalar(%arg0: f32) -> f32 {
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @erfc_scalar(
|
||||
// CHECK-SAME: %[[val_arg0:.*]]: f32) -> f32 {
|
||||
// CHECK-DAG: %[[c127_i32:.*]] = arith.constant 127 : i32
|
||||
// CHECK-DAG: %[[c23_i32:.*]] = arith.constant 23 : i32
|
||||
// CHECK-DAG: %[[cst:.*]] = arith.constant 1.270000e+02 : f32
|
||||
// CHECK-DAG: %[[cst_0:.*]] = arith.constant -1.270000e+02 : f32
|
||||
// CHECK-DAG: %[[cst_1:.*]] = arith.constant 8.880000e+01 : f32
|
||||
// CHECK-DAG: %[[cst_2:.*]] = arith.constant -8.780000e+01 : f32
|
||||
// CHECK-DAG: %[[cst_3:.*]] = arith.constant 0.166666657 : f32
|
||||
// CHECK-DAG: %[[cst_4:.*]] = arith.constant 0.0416657962 : f32
|
||||
// CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.00833345205 : f32
|
||||
// CHECK-DAG: %[[cst_6:.*]] = arith.constant 0.00139819994 : f32
|
||||
// CHECK-DAG: %[[cst_7:.*]] = arith.constant 1.98756912E-4 : f32
|
||||
// CHECK-DAG: %[[cst_8:.*]] = arith.constant 2.12194442E-4 : f32
|
||||
// CHECK-DAG: %[[cst_9:.*]] = arith.constant -0.693359375 : f32
|
||||
// CHECK-DAG: %[[cst_10:.*]] = arith.constant 1.44269502 : f32
|
||||
// CHECK-DAG: %[[cst_11:.*]] = arith.constant 0.276978403 : f32
|
||||
// CHECK-DAG: %[[cst_12:.*]] = arith.constant -0.0927639827 : f32
|
||||
// CHECK-DAG: %[[cst_13:.*]] = arith.constant -0.166031361 : f32
|
||||
// CHECK-DAG: %[[cst_14:.*]] = arith.constant 0.164055392 : f32
|
||||
// CHECK-DAG: %[[cst_15:.*]] = arith.constant -0.0542046614 : f32
|
||||
// CHECK-DAG: %[[cst_16:.*]] = arith.constant -8.059920e-03 : f32
|
||||
// CHECK-DAG: %[[cst_17:.*]] = arith.constant 0.00863227434 : f32
|
||||
// CHECK-DAG: %[[cst_18:.*]] = arith.constant 0.00131355342 : f32
|
||||
// CHECK-DAG: %[[cst_19:.*]] = arith.constant -0.0012307521 : f32
|
||||
// CHECK-DAG: %[[cst_20:.*]] = arith.constant -4.01139259E-4 : f32
|
||||
// CHECK-DAG: %[[cst_true:.*]] = arith.constant true
|
||||
// CHECK-DAG: %[[cst_21:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[cst_22:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK-DAG: %[[cst_23:.*]] = arith.constant 5.000000e-01 : f32
|
||||
// CHECK-DAG: %[[cst_24:.*]] = arith.constant -4.000000e+00 : f32
|
||||
// CHECK-DAG: %[[cst_25:.*]] = arith.constant -2.000000e+00 : f32
|
||||
// CHECK-DAG: %[[cst_26:.*]] = arith.constant 2.000000e+00 : f32
|
||||
// CHECK-DAG: %[[cst_27:.*]] = arith.constant 0x7F800000 : f32
|
||||
// CHECK-DAG: %[[cst_28:.*]] = arith.constant 10.0546875 : f32
|
||||
// CHECK: %[[val_2:.*]] = math.absf %[[val_arg0]] : f32
|
||||
// CHECK-NEXT: %[[val_3:.*]] = arith.addf %[[val_2]], %[[cst_26]] : f32
|
||||
// CHECK-NEXT: %[[val_4:.*]] = arith.divf %[[cst_22]], %[[val_3]] : f32
|
||||
// CHECK-NEXT: %[[val_5:.*]] = math.fma %[[cst_24]], %[[val_4]], %[[cst_22]] : f32
|
||||
// CHECK-NEXT: %[[val_6:.*]] = arith.addf %[[val_5]], %[[cst_22]] : f32
|
||||
// CHECK-NEXT: %[[val_7:.*]] = math.fma %[[val_6]], %[[cst_25]], %[[val_2]] : f32
|
||||
// CHECK-NEXT: %[[val_8:.*]] = arith.negf %[[val_2]] : f32
|
||||
// CHECK-NEXT: %[[val_9:.*]] = math.fma %[[val_8]], %[[val_5]], %[[val_7]] : f32
|
||||
// CHECK-NEXT: %[[val_10:.*]] = math.fma %[[val_4]], %[[val_9]], %[[val_5]] : f32
|
||||
// CHECK-NEXT: %[[val_11:.*]] = math.fma %[[cst_20]], %[[val_10]], %[[cst_19]] : f32
|
||||
// CHECK-NEXT: %[[val_12:.*]] = math.fma %[[val_11]], %[[val_10]], %[[cst_18]] : f32
|
||||
// CHECK-NEXT: %[[val_13:.*]] = math.fma %[[val_12]], %[[val_10]], %[[cst_17]] : f32
|
||||
// CHECK-NEXT: %[[val_14:.*]] = math.fma %[[val_13]], %[[val_10]], %[[cst_16]] : f32
|
||||
// CHECK-NEXT: %[[val_15:.*]] = math.fma %[[val_14]], %[[val_10]], %[[cst_15]] : f32
|
||||
// CHECK-NEXT: %[[val_16:.*]] = math.fma %[[val_15]], %[[val_10]], %[[cst_14]] : f32
|
||||
// CHECK-NEXT: %[[val_17:.*]] = math.fma %[[val_16]], %[[val_10]], %[[cst_13]] : f32
|
||||
// CHECK-NEXT: %[[val_18:.*]] = math.fma %[[val_17]], %[[val_10]], %[[cst_12]] : f32
|
||||
// CHECK-NEXT: %[[val_19:.*]] = math.fma %[[val_18]], %[[val_10]], %[[cst_11]] : f32
|
||||
// CHECK-NEXT: %[[val_20:.*]] = math.fma %[[cst_26]], %[[val_2]], %[[cst_22]] : f32
|
||||
// CHECK-NEXT: %[[val_21:.*]] = arith.divf %[[cst_22]], %[[val_20]] : f32
|
||||
// CHECK-NEXT: %[[val_22:.*]] = math.fma %[[val_19]], %[[val_21]], %[[val_21]] : f32
|
||||
// CHECK-NEXT: %[[val_23:.*]] = arith.negf %[[val_2]] : f32
|
||||
// CHECK-NEXT: %[[val_24:.*]] = math.fma %[[val_22]], %[[val_23]], %[[cst_23]] : f32
|
||||
// CHECK-NEXT: %[[val_25:.*]] = arith.subf %[[val_19]], %[[val_22]] : f32
|
||||
// CHECK-NEXT: %[[val_26:.*]] = math.fma %[[val_24]], %[[cst_26]], %[[val_25]] : f32
|
||||
// CHECK-NEXT: %[[val_27:.*]] = math.fma %[[val_26]], %[[val_21]], %[[val_22]] : f32
|
||||
// CHECK-NEXT: %[[val_28:.*]] = arith.mulf %[[val_2]], %[[val_2]] : f32
|
||||
// CHECK-NEXT: %[[val_29:.*]] = arith.negf %[[val_28]] : f32
|
||||
// CHECK-NEXT: %[[val_30:.*]] = arith.cmpf uge, %[[val_29]], %[[cst_2]] : f32
|
||||
// CHECK-NEXT: %[[val_31:.*]] = arith.select %[[val_30]], %[[val_29]], %[[cst_2]] : f32
|
||||
// CHECK-NEXT: %[[val_32:.*]] = arith.cmpf ule, %[[val_31]], %[[cst_1]] : f32
|
||||
// CHECK-NEXT: %[[val_33:.*]] = arith.select %[[val_32]], %[[val_31]], %[[cst_1]] : f32
|
||||
// CHECK-NEXT: %[[val_34:.*]] = math.fma %[[val_33]], %[[cst_10]], %[[cst_23]] : f32
|
||||
// CHECK-NEXT: %[[val_35:.*]] = math.floor %[[val_34]] : f32
|
||||
// CHECK-NEXT: %[[val_36:.*]] = arith.cmpf uge, %[[val_35]], %[[cst_0]] : f32
|
||||
// CHECK-NEXT: %[[val_37:.*]] = arith.select %[[val_36]], %[[val_35]], %[[cst_0]] : f32
|
||||
// CHECK-NEXT: %[[val_38:.*]] = arith.cmpf ule, %[[val_37]], %[[cst]] : f32
|
||||
// CHECK-NEXT: %[[val_39:.*]] = arith.select %[[val_38]], %[[val_37]], %[[cst]] : f32
|
||||
// CHECK-NEXT: %[[val_40:.*]] = math.fma %[[cst_9]], %[[val_39]], %[[val_33]] : f32
|
||||
// CHECK-NEXT: %[[val_41:.*]] = math.fma %[[cst_8]], %[[val_39]], %[[val_40]] : f32
|
||||
// CHECK-NEXT: %[[val_42:.*]] = math.fma %[[val_41]], %[[cst_7]], %[[cst_6]] : f32
|
||||
// CHECK-NEXT: %[[val_43:.*]] = math.fma %[[val_42]], %[[val_41]], %[[cst_5]] : f32
|
||||
// CHECK-NEXT: %[[val_44:.*]] = math.fma %[[val_43]], %[[val_41]], %[[cst_4]] : f32
|
||||
// CHECK-NEXT: %[[val_45:.*]] = math.fma %[[val_44]], %[[val_41]], %[[cst_3]] : f32
|
||||
// CHECK-NEXT: %[[val_46:.*]] = math.fma %[[val_45]], %[[val_41]], %[[cst_23]] : f32
|
||||
// CHECK-NEXT: %[[val_47:.*]] = arith.mulf %[[val_41]], %[[val_41]] : f32
|
||||
// CHECK-NEXT: %[[val_48:.*]] = math.fma %[[val_46]], %[[val_47]], %[[val_41]] : f32
|
||||
// CHECK-NEXT: %[[val_49:.*]] = arith.addf %[[val_48]], %[[cst_22]] : f32
|
||||
// CHECK-NEXT: %[[val_50:.*]] = arith.fptosi %[[val_39]] : f32 to i32
|
||||
// CHECK-NEXT: %[[val_51:.*]] = arith.addi %[[val_50]], %[[c127_i32]] : i32
|
||||
// CHECK-NEXT: %[[val_52:.*]] = arith.shli %[[val_51]], %[[c23_i32]] : i32
|
||||
// CHECK-NEXT: %[[val_53:.*]] = arith.bitcast %[[val_52]] : i32 to f32
|
||||
// CHECK-NEXT: %[[val_54:.*]] = arith.mulf %[[val_49]], %[[val_53]] : f32
|
||||
// CHECK-NEXT: %[[val_55:.*]] = arith.negf %[[val_2]] : f32
|
||||
// CHECK-NEXT: %[[val_56:.*]] = math.fma %[[val_55]], %[[val_2]], %[[val_28]] : f32
|
||||
// CHECK-NEXT: %[[val_57:.*]] = arith.mulf %[[val_27]], %[[val_54]] : f32
|
||||
// CHECK-NEXT: %[[val_58:.*]] = arith.mulf %[[val_57]], %[[val_56]] : f32
|
||||
// CHECK-NEXT: %[[val_59:.*]] = math.fma %[[val_27]], %[[val_54]], %[[val_58]] : f32
|
||||
// CHECK-NEXT: %[[val_60:.*]] = arith.cmpf olt, %[[val_2]], %[[cst_27]] : f32
|
||||
// CHECK-NEXT: %[[val_61:.*]] = arith.xori %[[val_60]], %[[cst_true]] : i1
|
||||
// CHECK-NEXT: %[[val_62:.*]] = arith.addf %[[val_arg0]], %[[val_arg0]] : f32
|
||||
// CHECK-NEXT: %[[val_63:.*]] = arith.select %[[val_61]], %[[val_62]], %[[val_59]] : f32
|
||||
// CHECK-NEXT: %[[val_64:.*]] = arith.cmpf ogt, %[[val_2]], %[[cst_28]] : f32
|
||||
// CHECK-NEXT: %[[val_65:.*]] = arith.select %[[val_64]], %[[cst_21]], %[[val_63]] : f32
|
||||
// CHECK-NEXT: %[[val_66:.*]] = arith.cmpf olt, %[[val_arg0]], %[[cst_21]] : f32
|
||||
// CHECK-NEXT: %[[val_67:.*]] = arith.subf %[[cst_26]], %[[val_65]] : f32
|
||||
// CHECK-NEXT: %[[val_68:.*]] = arith.select %[[val_66]], %[[val_67]], %[[val_65]] : f32
|
||||
// CHECK-NEXT: return %[[val_68]] : f32
|
||||
// CHECK-NEXT: }
|
||||
|
||||
func.func @erfc_scalar(%arg0: f32) -> f32 {
|
||||
%0 = math.erfc %arg0 : f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @erf_vector(
|
||||
// CHECK-SAME: %[[arg0:.*]]: vector<8xf32>) -> vector<8xf32> {
|
||||
// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>
|
||||
|
||||
@ -273,6 +273,77 @@ func.func @erf() {
|
||||
return
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// Erfc.
|
||||
// -------------------------------------------------------------------------- //
|
||||
func.func @erfc_f32(%a : f32) {
|
||||
%r = math.erfc %a : f32
|
||||
vector.print %r : f32
|
||||
return
|
||||
}
|
||||
|
||||
func.func @erfc_4xf32(%a : vector<4xf32>) {
|
||||
%r = math.erfc %a : vector<4xf32>
|
||||
vector.print %r : vector<4xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func.func @erfc() {
|
||||
// CHECK: 1.00027
|
||||
%val1 = arith.constant -2.431864e-4 : f32
|
||||
call @erfc_f32(%val1) : (f32) -> ()
|
||||
|
||||
// CHECK: 0.257905
|
||||
%val2 = arith.constant 0.79999 : f32
|
||||
call @erfc_f32(%val2) : (f32) -> ()
|
||||
|
||||
// CHECK: 0.257899
|
||||
%val3 = arith.constant 0.8 : f32
|
||||
call @erfc_f32(%val3) : (f32) -> ()
|
||||
|
||||
// CHECK: 0.00467794
|
||||
%val4 = arith.constant 1.99999 : f32
|
||||
call @erfc_f32(%val4) : (f32) -> ()
|
||||
|
||||
// CHECK: 0.00467774
|
||||
%val5 = arith.constant 2.0 : f32
|
||||
call @erfc_f32(%val5) : (f32) -> ()
|
||||
|
||||
// CHECK: 1.13736e-07
|
||||
%val6 = arith.constant 3.74999 : f32
|
||||
call @erfc_f32(%val6) : (f32) -> ()
|
||||
|
||||
// CHECK: 1.13727e-07
|
||||
%val7 = arith.constant 3.75 : f32
|
||||
call @erfc_f32(%val7) : (f32) -> ()
|
||||
|
||||
// CHECK: 2
|
||||
%negativeInf = arith.constant 0xff800000 : f32
|
||||
call @erfc_f32(%negativeInf) : (f32) -> ()
|
||||
|
||||
// CHECK: 2, 2, 1.91376, 1.73145
|
||||
%vecVals1 = arith.constant dense<[-3.4028235e+38, -4.54318, -1.2130899, -7.8234202e-01]> : vector<4xf32>
|
||||
call @erfc_4xf32(%vecVals1) : (vector<4xf32>) -> ()
|
||||
|
||||
// CHECK: 1, 1, 1, 0.878681
|
||||
%vecVals2 = arith.constant dense<[-1.1754944e-38, 0.0, 1.1754944e-38, 1.0793410e-01]> : vector<4xf32>
|
||||
call @erfc_4xf32(%vecVals2) : (vector<4xf32>) -> ()
|
||||
|
||||
// CHECK: 0.0805235, 0.000931045, 6.40418e-08, 0
|
||||
%vecVals3 = arith.constant dense<[1.23578, 2.34093, 3.82342, 3.4028235e+38]> : vector<4xf32>
|
||||
call @erfc_4xf32(%vecVals3) : (vector<4xf32>) -> ()
|
||||
|
||||
// CHECK: 0
|
||||
%inf = arith.constant 0x7f800000 : f32
|
||||
call @erfc_f32(%inf) : (f32) -> ()
|
||||
|
||||
// CHECK: nan
|
||||
%nan = arith.constant 0x7fc00000 : f32
|
||||
call @erfc_f32(%nan) : (f32) -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// Exp.
|
||||
// -------------------------------------------------------------------------- //
|
||||
@ -772,6 +843,7 @@ func.func @main() {
|
||||
call @log2(): () -> ()
|
||||
call @log1p(): () -> ()
|
||||
call @erf(): () -> ()
|
||||
call @erfc(): () -> ()
|
||||
call @exp(): () -> ()
|
||||
call @expm1(): () -> ()
|
||||
call @sin(): () -> ()
|
||||
|
||||
@ -44,6 +44,7 @@ syn keyword mlirOps view
|
||||
|
||||
" Math ops.
|
||||
syn match mlirOps /\<math\.erf\>/
|
||||
syn match mlirOps /\<math\.erfc\>/
|
||||
|
||||
" Affine ops.
|
||||
syn match mlirOps /\<affine\.apply\>/
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user