llvm-project/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Eugene Zhulenev f99ccf6516 [mlir] Add math polynomial approximation pass
This gives a ~4.6x speedup for small inputs and up to ~30x for large inputs, compared to expanding Tanh into exp operations:

```
name                  old cpu/op  new cpu/op  delta
BM_mlir_Tanh_f32/10    253ns ± 3%    55ns ± 7%  -78.35%  (p=0.000 n=44+41)
BM_mlir_Tanh_f32/100  2.21µs ± 4%  0.14µs ± 8%  -93.85%  (p=0.000 n=48+49)
BM_mlir_Tanh_f32/1k   22.6µs ± 4%   0.7µs ± 5%  -96.68%  (p=0.000 n=32+42)
BM_mlir_Tanh_f32/10k   225µs ± 5%     7µs ± 6%  -96.88%  (p=0.000 n=49+55)

name                  old time/op             new time/op             delta
BM_mlir_Tanh_f32/10    259ns ± 1%               56ns ± 2%  -78.31%        (p=0.000 n=41+39)
BM_mlir_Tanh_f32/100  2.27µs ± 1%             0.14µs ± 5%  -93.89%        (p=0.000 n=46+49)
BM_mlir_Tanh_f32/1k   22.9µs ± 1%              0.8µs ± 4%  -96.67%        (p=0.000 n=30+42)
BM_mlir_Tanh_f32/10k   230µs ± 0%                7µs ± 3%  -96.88%        (p=0.000 n=37+55)
```

This approximation is based on the Eigen::generic_fast_tanh function.

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96739
2021-02-19 12:43:36 -08:00

195 lines
6.7 KiB
C++

//===- PolynomialApproximation.cpp - Approximate math operations ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of math operations to fast approximations
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;
/// Returns true when `type` is a float scalar or a vector of floats; these
/// are the only operand types the approximation patterns accept.
static bool isValidFloatType(Type type) {
  Type elementType = type;
  if (auto vectorType = type.dyn_cast<VectorType>())
    elementType = vectorType.getElementType();
  return elementType.isa<FloatType>();
}
//----------------------------------------------------------------------------//
// A PatternRewriter wrapper that provides a concise API for building
// expansions for operations on float scalars or vectors.
//----------------------------------------------------------------------------//
namespace {
/// Convenience layer over a PatternRewriter for emitting arithmetic on either
/// a float scalar or a vector of floats. All operations are created at the
/// stored location; `constant` broadcasts scalars when the builder was
/// constructed for a vector type.
class FloatApproximationBuilder {
public:
  FloatApproximationBuilder(Location loc, Type type, PatternRewriter &rewriter);

  /// Materializes `value` as a constant of the builder's (scalar or vector)
  /// type.
  Value constant(double value) const;

  // Elementwise arithmetic.
  Value abs(Value a) const;
  Value mul(Value a, Value b) const;
  Value div(Value a, Value b) const;
  /// Fused multiply-add: computes `a * b + c`.
  Value madd(Value a, Value b, Value c) const;

  // Comparison-based helpers.
  Value min(Value a, Value b) const;
  Value max(Value a, Value b) const;
  /// Compares `a` against `b` using `predicate`.
  Value cmp(CmpFPredicate predicate, Value a, Value b) const;
  /// Chooses between `a` and `b` according to the `predicate` mask.
  Value select(Value predicate, Value a, Value b) const;

private:
  Location loc;
  PatternRewriter &rewriter;
  /// Null when the builder was created for a scalar float type.
  VectorType vectorType;
  /// The float element type (the type itself in the scalar case).
  FloatType elementType;
};
} // namespace
/// Records the insertion location and rewriter, and splits `type` into an
/// optional vector shape plus its float element type. `type` must be a float
/// or a vector of floats; otherwise the `cast` asserts.
FloatApproximationBuilder::FloatApproximationBuilder(Location loc, Type type,
                                                     PatternRewriter &rewriter)
    : loc(loc), rewriter(rewriter), vectorType(type.dyn_cast<VectorType>()) {
  Type scalarType = vectorType ? vectorType.getElementType() : type;
  elementType = scalarType.cast<FloatType>();
}
/// Creates a scalar constant of the element type and, in the vector case,
/// broadcasts it to the full vector type.
Value FloatApproximationBuilder::constant(double value) const {
  Value scalar =
      rewriter.create<ConstantOp>(loc, rewriter.getFloatAttr(elementType, value));
  if (!vectorType)
    return scalar;
  return rewriter.create<BroadcastOp>(loc, vectorType, scalar);
}
/// Emits the floating-point absolute value of `a`.
Value FloatApproximationBuilder::abs(Value a) const {
  Value magnitude = rewriter.create<AbsFOp>(loc, a);
  return magnitude;
}
/// Returns the smaller of `a` and `b`. The comparison uses the ordered `OLT`
/// predicate, so when either input is NaN it is false and `b` is returned.
Value FloatApproximationBuilder::min(Value a, Value b) const {
  Value aIsLess = cmp(CmpFPredicate::OLT, a, b);
  return select(aIsLess, a, b);
}
/// Returns the larger of `a` and `b`. The comparison uses the ordered `OGT`
/// predicate, so when either input is NaN it is false and `b` is returned.
Value FloatApproximationBuilder::max(Value a, Value b) const {
  Value aIsGreater = cmp(CmpFPredicate::OGT, a, b);
  return select(aIsGreater, a, b);
}
/// Emits the floating-point product `a * b`.
Value FloatApproximationBuilder::mul(Value a, Value b) const {
  Value product = rewriter.create<MulFOp>(loc, a, b);
  return product;
}
/// Emits the floating-point quotient `a / b`.
Value FloatApproximationBuilder::div(Value a, Value b) const {
  Value quotient = rewriter.create<DivFOp>(loc, a, b);
  return quotient;
}
/// Emits a fused multiply-add computing `a * b + c`.
Value FloatApproximationBuilder::madd(Value a, Value b, Value c) const {
  Value fused = rewriter.create<FmaFOp>(loc, a, b, c);
  return fused;
}
/// Emits a floating-point comparison of `a` against `b` using `predicate`.
Value FloatApproximationBuilder::cmp(CmpFPredicate predicate, Value a,
                                     Value b) const {
  Value mask = rewriter.create<CmpFOp>(loc, predicate, a, b);
  return mask;
}
/// Emits a select that yields `a` where `predicate` is set and `b` otherwise.
Value FloatApproximationBuilder::select(Value predicate, Value a,
                                        Value b) const {
  Value chosen = rewriter.create<SelectOp>(loc, predicate, a, b);
  return chosen;
}
//----------------------------------------------------------------------------//
// TanhOp approximation.
//----------------------------------------------------------------------------//
namespace {
/// Rewrite pattern that replaces `math.tanh` on float scalars or vectors with
/// a polynomial approximation (see `matchAndRewrite`).
struct TanhApproximation : public OpRewritePattern<math::TanhOp> {
  using OpRewritePattern<math::TanhOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(math::TanhOp op,
                                PatternRewriter &rewriter) const final;
};
} // namespace
/// Rewrites `math.tanh` on a float scalar or vector into a rational polynomial
/// approximation built only from arithmetic, compare and select operations
/// (no library calls). The scheme follows Eigen's generic_fast_tanh:
///   1. clamp the input to a symmetric interval;
///   2. evaluate an odd degree-13 numerator p(x) and an even degree-6
///      denominator q(x);
///   3. return p(x) / q(x).
/// Inputs with |x| < 0.0004 are returned unchanged, and so are NaN inputs.
///
/// NOTE(review): the clamp bound and coefficients come from Eigen's float
/// implementation, so f64 operands presumably only get roughly f32 accuracy.
/// Confirm this before relying on the pattern for f64.
LogicalResult
TanhApproximation::matchAndRewrite(math::TanhOp op,
                                   PatternRewriter &rewriter) const {
  if (!isValidFloatType(op.operand().getType()))
    return rewriter.notifyMatchFailure(op, "unsupported operand type");

  Value operand = op.operand();
  FloatApproximationBuilder builder(op->getLoc(), operand.getType(), rewriter);

  // Clamp operand into [minusClamp, plusClamp]. The rational approximation
  // below is only evaluated on this symmetric interval.
  constexpr double clampBound = 7.90531110763549805;
  Value plusClamp = builder.constant(clampBound);
  Value minusClamp = builder.constant(-clampBound);
  Value x = builder.max(builder.min(operand, plusClamp), minusClamp);

  // Mask for tiny values, which are approximated by the operand itself
  // (tanh(x) ~= x near zero).
  //
  // The unordered `ULT` predicate is also true when the operand is NaN, so a
  // NaN takes this branch and propagates. Without that, the builder's
  // `min`/`max` would replace NaN with a clamp bound, because they use ordered
  // comparisons, and tanh(NaN) would evaluate to roughly 1.
  Value tiny = builder.constant(0.0004);
  Value tinyMask = builder.cmp(CmpFPredicate::ULT, builder.abs(operand), tiny);

  // The monomial coefficients of the numerator polynomial (odd).
  Value alpha1 = builder.constant(4.89352455891786e-03);
  Value alpha3 = builder.constant(6.37261928875436e-04);
  Value alpha5 = builder.constant(1.48572235717979e-05);
  Value alpha7 = builder.constant(5.12229709037114e-08);
  Value alpha9 = builder.constant(-8.60467152213735e-11);
  Value alpha11 = builder.constant(2.00018790482477e-13);
  Value alpha13 = builder.constant(-2.76076847742355e-16);

  // The monomial coefficients of the denominator polynomial (even).
  Value beta0 = builder.constant(4.89352518554385e-03);
  Value beta2 = builder.constant(2.26843463243900e-03);
  Value beta4 = builder.constant(1.18534705686654e-04);
  Value beta6 = builder.constant(1.19825839466702e-06);

  // Since the polynomials are odd/even, we need x^2.
  Value x2 = builder.mul(x, x);

  // Evaluate the numerator polynomial p with Horner's scheme in x^2, then
  // multiply by x to make it odd.
  Value p = builder.madd(x2, alpha13, alpha11);
  p = builder.madd(x2, p, alpha9);
  p = builder.madd(x2, p, alpha7);
  p = builder.madd(x2, p, alpha5);
  p = builder.madd(x2, p, alpha3);
  p = builder.madd(x2, p, alpha1);
  p = builder.mul(x, p);

  // Evaluate the denominator polynomial q with Horner's scheme in x^2.
  Value q = builder.madd(x2, beta6, beta4);
  q = builder.madd(x2, q, beta2);
  q = builder.madd(x2, q, beta0);

  // Divide the numerator by the denominator. Lanes in the mask return
  // `operand`; for tiny values this equals `x`, since clamping leaves them
  // unchanged.
  Value res = builder.select(tinyMask, operand, builder.div(p, q));
  rewriter.replaceOp(op, res);
  return success();
}
//----------------------------------------------------------------------------//
/// Adds every polynomial-approximation rewrite defined in this file to
/// `patterns`. At present that is only the `math.tanh` approximation.
void mlir::populateMathPolynomialApproximationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<TanhApproximation>(ctx);
}