[MLIR][Math] Add erfc to math dialect (#126439)

This patch adds the erfc op to the math dialect. It also does lowering
of the math.erfc op to libm calls. There is also a f32 polynomial
approximation for the function based on

https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
This is in turn based on
M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
(1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol.
36, No. 153, January 1981, pp. 249-253.
The code has a ULP error less than 3, which was tested, and MLIR test
values were verified against the C implementation.
This commit is contained in:
Jan Leyonberg 2025-02-18 10:51:37 -05:00 committed by GitHub
parent e1a393e392
commit 8806311bb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 353 additions and 5 deletions

View File

@ -560,6 +560,31 @@ def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// ErfcOp
//===----------------------------------------------------------------------===//
def Math_ErfcOp : Math_FloatUnaryOp<"erfc"> {
let summary = "complementary error function of the specified value";
let description = [{
The `erfc` operation computes the complementary error function, defined as
1-erf(x). This function is part of libm and is needed for accuracy, since
simply calculating 1-erf(x) when x is close to 1 will give inaccurate results.
It takes one operand of floating point type (i.e., scalar,
tensor or vector) and returns one result of the same type. It has no
standard attributes.
Example:
```mlir
// Scalar error function value.
%a = math.erfc %b : f64
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// ExpOp

View File

@ -23,6 +23,14 @@ public:
PatternRewriter &rewriter) const final;
};
struct ErfcPolynomialApproximation : public OpRewritePattern<math::ErfcOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(math::ErfcOp op,
PatternRewriter &rewriter) const final;
};
} // namespace math
} // namespace mlir

View File

@ -47,6 +47,7 @@ struct MathPolynomialApproximationOptions {
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
// Adds patterns to convert to f32 around math functions for which `predicate`
// returns true.

View File

@ -181,6 +181,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
populatePatternsForOp<math::ErfcOp>(patterns, benefit, ctx, "erfcf", "erfc");
populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",

View File

@ -332,6 +332,24 @@ OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
});
}
//===----------------------------------------------------------------------===//
// ErfcOp folder
//===----------------------------------------------------------------------===//
OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
return constFoldUnaryOpConditional<FloatAttr>(
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
switch (APFloat::SemanticsToEnum(a.getSemantics())) {
case APFloat::Semantics::S_IEEEdouble:
return APFloat(erfc(a.convertToDouble()));
case APFloat::Semantics::S_IEEEsingle:
return APFloat(erfcf(a.convertToFloat()));
default:
return {};
}
});
}
//===----------------------------------------------------------------------===//
// IPowIOp folder
//===----------------------------------------------------------------------===//

View File

@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Helper functions to create constants.
//----------------------------------------------------------------------------//
static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
}
static Value floatCst(ImplicitLocOpBuilder &builder, float value,
Type elementType) {
assert((elementType.isF16() || elementType.isF32()) &&
@ -1118,6 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
return success();
}
// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a 9 degree
// polynomial.This approximation is based on the following stackoverflow post:
// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
// The stackoverflow post is in turn based on:
// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
// No. 153, January 1981, pp. 249-253.
//
// Maximum error: 2.65 ulps
LogicalResult
ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
PatternRewriter &rewriter) const {
Value x = op.getOperand();
Type et = getElementTypeOrSelf(x);
if (!et.isF32())
return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
std::optional<VectorShape> shape = vectorShape(x);
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
};
Value trueValue = bcast(boolCst(builder, true));
Value zero = bcast(floatCst(builder, 0.0f, et));
Value one = bcast(floatCst(builder, 1.0f, et));
Value onehalf = bcast(floatCst(builder, 0.5f, et));
Value neg4 = bcast(floatCst(builder, -4.0f, et));
Value neg2 = bcast(floatCst(builder, -2.0f, et));
Value pos2 = bcast(floatCst(builder, 2.0f, et));
Value posInf = bcast(floatCst(builder, INFINITY, et));
Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
Value a = builder.create<math::AbsFOp>(x);
Value p = builder.create<arith::AddFOp>(a, pos2);
Value r = builder.create<arith::DivFOp>(one, p);
Value q = builder.create<math::FmaOp>(neg4, r, one);
Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
neg2, a);
Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
q = builder.create<math::FmaOp>(r, e, q);
p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
p = builder.create<math::FmaOp>(p, q, c1);
Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
p = builder.create<math::FmaOp>(p, q, c2);
Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
p = builder.create<math::FmaOp>(p, q, c3);
Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
p = builder.create<math::FmaOp>(p, q, c4);
Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
p = builder.create<math::FmaOp>(p, q, c5);
Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
p = builder.create<math::FmaOp>(p, q, c6);
Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
p = builder.create<math::FmaOp>(p, q, c7);
Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
p = builder.create<math::FmaOp>(p, q, c8);
Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
p = builder.create<math::FmaOp>(p, q, c9);
Value d = builder.create<math::FmaOp>(pos2, a, one);
r = builder.create<arith::DivFOp>(one, d);
q = builder.create<math::FmaOp>(p, r, r);
Value negfa = builder.create<arith::NegFOp>(a);
Value fmaqah = builder.create<math::FmaOp>(q, negfa, onehalf);
Value psubq = builder.create<arith::SubFOp>(p, q);
e = builder.create<math::FmaOp>(fmaqah, pos2, psubq);
r = builder.create<math::FmaOp>(e, r, q);
Value s = builder.create<arith::MulFOp>(a, a);
e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
r = builder.create<math::FmaOp>(
r, e,
builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
Value isNotLessThanInf = builder.create<arith::XOrIOp>(
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
trueValue);
r = builder.create<arith::SelectOp>(isNotLessThanInf,
builder.create<arith::AddFOp>(x, x), r);
Value isGreaterThanClamp =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
Value isNegative =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
r = builder.create<arith::SelectOp>(
isNegative, builder.create<arith::SubFOp>(pos2, r), r);
rewriter.replaceOp(op, r);
return success();
}
//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//
@ -1667,6 +1768,11 @@ void mlir::populatePolynomialApproximateErfPattern(
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
}
void mlir::populatePolynomialApproximateErfcPattern(
RewritePatternSet &patterns) {
patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
}
template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
@ -1690,6 +1796,7 @@ void mlir::populateMathF32ExpansionPatterns(
populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
@ -1734,6 +1841,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
patterns, predicate);
populateMathPolynomialApproximationPattern<ErfcOp,
ErfcPolynomialApproximation>(
patterns, predicate);
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
patterns, predicate);
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
@ -1760,9 +1870,10 @@ void mlir::populateMathPolynomialApproximationPatterns(
{math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
math::SinOp::getOperationName(), math::CosOp::getOperationName()},
math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
math::CosOp::getOperationName()},
name);
});
@ -1774,8 +1885,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
math::Log2Op::getOperationName(),
math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
math::AsinOp::getOperationName(), math::AcosOp::getOperationName(),
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
math::ExpM1Op::getOperationName(),
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
math::CosOp::getOperationName()},
name);

View File

@ -81,6 +81,116 @@ func.func @erf_scalar(%arg0: f32) -> f32 {
return %0 : f32
}
// CHECK-LABEL: func @erfc_scalar(
// CHECK-SAME: %[[val_arg0:.*]]: f32) -> f32 {
// CHECK-DAG: %[[c127_i32:.*]] = arith.constant 127 : i32
// CHECK-DAG: %[[c23_i32:.*]] = arith.constant 23 : i32
// CHECK-DAG: %[[cst:.*]] = arith.constant 1.270000e+02 : f32
// CHECK-DAG: %[[cst_0:.*]] = arith.constant -1.270000e+02 : f32
// CHECK-DAG: %[[cst_1:.*]] = arith.constant 8.880000e+01 : f32
// CHECK-DAG: %[[cst_2:.*]] = arith.constant -8.780000e+01 : f32
// CHECK-DAG: %[[cst_3:.*]] = arith.constant 0.166666657 : f32
// CHECK-DAG: %[[cst_4:.*]] = arith.constant 0.0416657962 : f32
// CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.00833345205 : f32
// CHECK-DAG: %[[cst_6:.*]] = arith.constant 0.00139819994 : f32
// CHECK-DAG: %[[cst_7:.*]] = arith.constant 1.98756912E-4 : f32
// CHECK-DAG: %[[cst_8:.*]] = arith.constant 2.12194442E-4 : f32
// CHECK-DAG: %[[cst_9:.*]] = arith.constant -0.693359375 : f32
// CHECK-DAG: %[[cst_10:.*]] = arith.constant 1.44269502 : f32
// CHECK-DAG: %[[cst_11:.*]] = arith.constant 0.276978403 : f32
// CHECK-DAG: %[[cst_12:.*]] = arith.constant -0.0927639827 : f32
// CHECK-DAG: %[[cst_13:.*]] = arith.constant -0.166031361 : f32
// CHECK-DAG: %[[cst_14:.*]] = arith.constant 0.164055392 : f32
// CHECK-DAG: %[[cst_15:.*]] = arith.constant -0.0542046614 : f32
// CHECK-DAG: %[[cst_16:.*]] = arith.constant -8.059920e-03 : f32
// CHECK-DAG: %[[cst_17:.*]] = arith.constant 0.00863227434 : f32
// CHECK-DAG: %[[cst_18:.*]] = arith.constant 0.00131355342 : f32
// CHECK-DAG: %[[cst_19:.*]] = arith.constant -0.0012307521 : f32
// CHECK-DAG: %[[cst_20:.*]] = arith.constant -4.01139259E-4 : f32
// CHECK-DAG: %[[cst_true:.*]] = arith.constant true
// CHECK-DAG: %[[cst_21:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[cst_22:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-DAG: %[[cst_23:.*]] = arith.constant 5.000000e-01 : f32
// CHECK-DAG: %[[cst_24:.*]] = arith.constant -4.000000e+00 : f32
// CHECK-DAG: %[[cst_25:.*]] = arith.constant -2.000000e+00 : f32
// CHECK-DAG: %[[cst_26:.*]] = arith.constant 2.000000e+00 : f32
// CHECK-DAG: %[[cst_27:.*]] = arith.constant 0x7F800000 : f32
// CHECK-DAG: %[[cst_28:.*]] = arith.constant 10.0546875 : f32
// CHECK: %[[val_2:.*]] = math.absf %[[val_arg0]] : f32
// CHECK-NEXT: %[[val_3:.*]] = arith.addf %[[val_2]], %[[cst_26]] : f32
// CHECK-NEXT: %[[val_4:.*]] = arith.divf %[[cst_22]], %[[val_3]] : f32
// CHECK-NEXT: %[[val_5:.*]] = math.fma %[[cst_24]], %[[val_4]], %[[cst_22]] : f32
// CHECK-NEXT: %[[val_6:.*]] = arith.addf %[[val_5]], %[[cst_22]] : f32
// CHECK-NEXT: %[[val_7:.*]] = math.fma %[[val_6]], %[[cst_25]], %[[val_2]] : f32
// CHECK-NEXT: %[[val_8:.*]] = arith.negf %[[val_2]] : f32
// CHECK-NEXT: %[[val_9:.*]] = math.fma %[[val_8]], %[[val_5]], %[[val_7]] : f32
// CHECK-NEXT: %[[val_10:.*]] = math.fma %[[val_4]], %[[val_9]], %[[val_5]] : f32
// CHECK-NEXT: %[[val_11:.*]] = math.fma %[[cst_20]], %[[val_10]], %[[cst_19]] : f32
// CHECK-NEXT: %[[val_12:.*]] = math.fma %[[val_11]], %[[val_10]], %[[cst_18]] : f32
// CHECK-NEXT: %[[val_13:.*]] = math.fma %[[val_12]], %[[val_10]], %[[cst_17]] : f32
// CHECK-NEXT: %[[val_14:.*]] = math.fma %[[val_13]], %[[val_10]], %[[cst_16]] : f32
// CHECK-NEXT: %[[val_15:.*]] = math.fma %[[val_14]], %[[val_10]], %[[cst_15]] : f32
// CHECK-NEXT: %[[val_16:.*]] = math.fma %[[val_15]], %[[val_10]], %[[cst_14]] : f32
// CHECK-NEXT: %[[val_17:.*]] = math.fma %[[val_16]], %[[val_10]], %[[cst_13]] : f32
// CHECK-NEXT: %[[val_18:.*]] = math.fma %[[val_17]], %[[val_10]], %[[cst_12]] : f32
// CHECK-NEXT: %[[val_19:.*]] = math.fma %[[val_18]], %[[val_10]], %[[cst_11]] : f32
// CHECK-NEXT: %[[val_20:.*]] = math.fma %[[cst_26]], %[[val_2]], %[[cst_22]] : f32
// CHECK-NEXT: %[[val_21:.*]] = arith.divf %[[cst_22]], %[[val_20]] : f32
// CHECK-NEXT: %[[val_22:.*]] = math.fma %[[val_19]], %[[val_21]], %[[val_21]] : f32
// CHECK-NEXT: %[[val_23:.*]] = arith.negf %[[val_2]] : f32
// CHECK-NEXT: %[[val_24:.*]] = math.fma %[[val_22]], %[[val_23]], %[[cst_23]] : f32
// CHECK-NEXT: %[[val_25:.*]] = arith.subf %[[val_19]], %[[val_22]] : f32
// CHECK-NEXT: %[[val_26:.*]] = math.fma %[[val_24]], %[[cst_26]], %[[val_25]] : f32
// CHECK-NEXT: %[[val_27:.*]] = math.fma %[[val_26]], %[[val_21]], %[[val_22]] : f32
// CHECK-NEXT: %[[val_28:.*]] = arith.mulf %[[val_2]], %[[val_2]] : f32
// CHECK-NEXT: %[[val_29:.*]] = arith.negf %[[val_28]] : f32
// CHECK-NEXT: %[[val_30:.*]] = arith.cmpf uge, %[[val_29]], %[[cst_2]] : f32
// CHECK-NEXT: %[[val_31:.*]] = arith.select %[[val_30]], %[[val_29]], %[[cst_2]] : f32
// CHECK-NEXT: %[[val_32:.*]] = arith.cmpf ule, %[[val_31]], %[[cst_1]] : f32
// CHECK-NEXT: %[[val_33:.*]] = arith.select %[[val_32]], %[[val_31]], %[[cst_1]] : f32
// CHECK-NEXT: %[[val_34:.*]] = math.fma %[[val_33]], %[[cst_10]], %[[cst_23]] : f32
// CHECK-NEXT: %[[val_35:.*]] = math.floor %[[val_34]] : f32
// CHECK-NEXT: %[[val_36:.*]] = arith.cmpf uge, %[[val_35]], %[[cst_0]] : f32
// CHECK-NEXT: %[[val_37:.*]] = arith.select %[[val_36]], %[[val_35]], %[[cst_0]] : f32
// CHECK-NEXT: %[[val_38:.*]] = arith.cmpf ule, %[[val_37]], %[[cst]] : f32
// CHECK-NEXT: %[[val_39:.*]] = arith.select %[[val_38]], %[[val_37]], %[[cst]] : f32
// CHECK-NEXT: %[[val_40:.*]] = math.fma %[[cst_9]], %[[val_39]], %[[val_33]] : f32
// CHECK-NEXT: %[[val_41:.*]] = math.fma %[[cst_8]], %[[val_39]], %[[val_40]] : f32
// CHECK-NEXT: %[[val_42:.*]] = math.fma %[[val_41]], %[[cst_7]], %[[cst_6]] : f32
// CHECK-NEXT: %[[val_43:.*]] = math.fma %[[val_42]], %[[val_41]], %[[cst_5]] : f32
// CHECK-NEXT: %[[val_44:.*]] = math.fma %[[val_43]], %[[val_41]], %[[cst_4]] : f32
// CHECK-NEXT: %[[val_45:.*]] = math.fma %[[val_44]], %[[val_41]], %[[cst_3]] : f32
// CHECK-NEXT: %[[val_46:.*]] = math.fma %[[val_45]], %[[val_41]], %[[cst_23]] : f32
// CHECK-NEXT: %[[val_47:.*]] = arith.mulf %[[val_41]], %[[val_41]] : f32
// CHECK-NEXT: %[[val_48:.*]] = math.fma %[[val_46]], %[[val_47]], %[[val_41]] : f32
// CHECK-NEXT: %[[val_49:.*]] = arith.addf %[[val_48]], %[[cst_22]] : f32
// CHECK-NEXT: %[[val_50:.*]] = arith.fptosi %[[val_39]] : f32 to i32
// CHECK-NEXT: %[[val_51:.*]] = arith.addi %[[val_50]], %[[c127_i32]] : i32
// CHECK-NEXT: %[[val_52:.*]] = arith.shli %[[val_51]], %[[c23_i32]] : i32
// CHECK-NEXT: %[[val_53:.*]] = arith.bitcast %[[val_52]] : i32 to f32
// CHECK-NEXT: %[[val_54:.*]] = arith.mulf %[[val_49]], %[[val_53]] : f32
// CHECK-NEXT: %[[val_55:.*]] = arith.negf %[[val_2]] : f32
// CHECK-NEXT: %[[val_56:.*]] = math.fma %[[val_55]], %[[val_2]], %[[val_28]] : f32
// CHECK-NEXT: %[[val_57:.*]] = arith.mulf %[[val_27]], %[[val_54]] : f32
// CHECK-NEXT: %[[val_58:.*]] = arith.mulf %[[val_57]], %[[val_56]] : f32
// CHECK-NEXT: %[[val_59:.*]] = math.fma %[[val_27]], %[[val_54]], %[[val_58]] : f32
// CHECK-NEXT: %[[val_60:.*]] = arith.cmpf olt, %[[val_2]], %[[cst_27]] : f32
// CHECK-NEXT: %[[val_61:.*]] = arith.xori %[[val_60]], %[[cst_true]] : i1
// CHECK-NEXT: %[[val_62:.*]] = arith.addf %[[val_arg0]], %[[val_arg0]] : f32
// CHECK-NEXT: %[[val_63:.*]] = arith.select %[[val_61]], %[[val_62]], %[[val_59]] : f32
// CHECK-NEXT: %[[val_64:.*]] = arith.cmpf ogt, %[[val_2]], %[[cst_28]] : f32
// CHECK-NEXT: %[[val_65:.*]] = arith.select %[[val_64]], %[[cst_21]], %[[val_63]] : f32
// CHECK-NEXT: %[[val_66:.*]] = arith.cmpf olt, %[[val_arg0]], %[[cst_21]] : f32
// CHECK-NEXT: %[[val_67:.*]] = arith.subf %[[cst_26]], %[[val_65]] : f32
// CHECK-NEXT: %[[val_68:.*]] = arith.select %[[val_66]], %[[val_67]], %[[val_65]] : f32
// CHECK-NEXT: return %[[val_68]] : f32
// CHECK-NEXT: }
func.func @erfc_scalar(%arg0: f32) -> f32 {
%0 = math.erfc %arg0 : f32
return %0 : f32
}
// CHECK-LABEL: func @erf_vector(
// CHECK-SAME: %[[arg0:.*]]: vector<8xf32>) -> vector<8xf32> {
// CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32>

View File

@ -273,6 +273,77 @@ func.func @erf() {
return
}
// -------------------------------------------------------------------------- //
// Erfc.
// -------------------------------------------------------------------------- //
func.func @erfc_f32(%a : f32) {
%r = math.erfc %a : f32
vector.print %r : f32
return
}
func.func @erfc_4xf32(%a : vector<4xf32>) {
%r = math.erfc %a : vector<4xf32>
vector.print %r : vector<4xf32>
return
}
func.func @erfc() {
// CHECK: 1.00027
%val1 = arith.constant -2.431864e-4 : f32
call @erfc_f32(%val1) : (f32) -> ()
// CHECK: 0.257905
%val2 = arith.constant 0.79999 : f32
call @erfc_f32(%val2) : (f32) -> ()
// CHECK: 0.257899
%val3 = arith.constant 0.8 : f32
call @erfc_f32(%val3) : (f32) -> ()
// CHECK: 0.00467794
%val4 = arith.constant 1.99999 : f32
call @erfc_f32(%val4) : (f32) -> ()
// CHECK: 0.00467774
%val5 = arith.constant 2.0 : f32
call @erfc_f32(%val5) : (f32) -> ()
// CHECK: 1.13736e-07
%val6 = arith.constant 3.74999 : f32
call @erfc_f32(%val6) : (f32) -> ()
// CHECK: 1.13727e-07
%val7 = arith.constant 3.75 : f32
call @erfc_f32(%val7) : (f32) -> ()
// CHECK: 2
%negativeInf = arith.constant 0xff800000 : f32
call @erfc_f32(%negativeInf) : (f32) -> ()
// CHECK: 2, 2, 1.91376, 1.73145
%vecVals1 = arith.constant dense<[-3.4028235e+38, -4.54318, -1.2130899, -7.8234202e-01]> : vector<4xf32>
call @erfc_4xf32(%vecVals1) : (vector<4xf32>) -> ()
// CHECK: 1, 1, 1, 0.878681
%vecVals2 = arith.constant dense<[-1.1754944e-38, 0.0, 1.1754944e-38, 1.0793410e-01]> : vector<4xf32>
call @erfc_4xf32(%vecVals2) : (vector<4xf32>) -> ()
// CHECK: 0.0805235, 0.000931045, 6.40418e-08, 0
%vecVals3 = arith.constant dense<[1.23578, 2.34093, 3.82342, 3.4028235e+38]> : vector<4xf32>
call @erfc_4xf32(%vecVals3) : (vector<4xf32>) -> ()
// CHECK: 0
%inf = arith.constant 0x7f800000 : f32
call @erfc_f32(%inf) : (f32) -> ()
// CHECK: nan
%nan = arith.constant 0x7fc00000 : f32
call @erfc_f32(%nan) : (f32) -> ()
return
}
// -------------------------------------------------------------------------- //
// Exp.
// -------------------------------------------------------------------------- //
@ -772,6 +843,7 @@ func.func @main() {
call @log2(): () -> ()
call @log1p(): () -> ()
call @erf(): () -> ()
call @erfc(): () -> ()
call @exp(): () -> ()
call @expm1(): () -> ()
call @sin(): () -> ()

View File

@ -44,6 +44,7 @@ syn keyword mlirOps view
" Math ops.
syn match mlirOps /\<math\.erf\>/
syn match mlirOps /\<math\.erfc\>/
" Affine ops.
syn match mlirOps /\<affine\.apply\>/