llvm-project/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
Eugene Zhulenev f99ccf6516 [mlir] Add math polynomial approximation pass
This gives ~30x speedup compared to expanding Tanh into exp operations:

```
name                  old cpu/op  new cpu/op  delta
BM_mlir_Tanh_f32/10    253ns ± 3%    55ns ± 7%  -78.35%  (p=0.000 n=44+41)
BM_mlir_Tanh_f32/100  2.21µs ± 4%  0.14µs ± 8%  -93.85%  (p=0.000 n=48+49)
BM_mlir_Tanh_f32/1k   22.6µs ± 4%   0.7µs ± 5%  -96.68%  (p=0.000 n=32+42)
BM_mlir_Tanh_f32/10k   225µs ± 5%     7µs ± 6%  -96.88%  (p=0.000 n=49+55)

name                  old time/op             new time/op             delta
BM_mlir_Tanh_f32/10    259ns ± 1%               56ns ± 2%  -78.31%        (p=0.000 n=41+39)
BM_mlir_Tanh_f32/100  2.27µs ± 1%             0.14µs ± 5%  -93.89%        (p=0.000 n=46+49)
BM_mlir_Tanh_f32/1k   22.9µs ± 1%              0.8µs ± 4%  -96.67%        (p=0.000 n=30+42)
BM_mlir_Tanh_f32/10k   230µs ± 0%                7µs ± 3%  -96.88%        (p=0.000 n=37+55)
```

This approximations is based on Eigen::generic_fast_tanh function

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D96739
2021-02-19 12:43:36 -08:00

47 lines
1.6 KiB
C++

//===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains test passes for expanding math operations into
// polynomial approximations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TestMathPolynomialApproximationPass
: public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
void runOnFunction() override;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect, math::MathDialect>();
}
};
} // end anonymous namespace
void TestMathPolynomialApproximationPass::runOnFunction() {
OwningRewritePatternList patterns;
populateMathPolynomialApproximationPatterns(patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
namespace mlir {
namespace test {
void registerTestMathPolynomialApproximationPass() {
PassRegistration<TestMathPolynomialApproximationPass> pass(
"test-math-polynomial-approximation",
"Test math polynomial approximations");
}
} // namespace test
} // namespace mlir