
Default vector.contract lowering essentially yields a series of sdot/ddot operations. However, for some layouts a series of saxpy/daxpy operations, chained through fma, is more efficient. This CL introduces a choice between the two lowering paths. A default heuristic will follow in a later change.

Some preliminary avx2 performance numbers for matrix-times-vector follow. Here, dot performs best for 64x64 A x b and saxpy for 64x64 A^T x b.

```
------------------------------------------------------------
                A x b                     A^T x b
------------------------------------------------------------
GFLOPS     sdot (reassoc)  saxpy    sdot (reassoc)  saxpy
------------------------------------------------------------
1x1             0.6          0.9          0.6          0.9
2x2             2.5          3.2          2.4          3.5
4x4             6.4          8.4          4.9         11.8
8x8            11.7          6.1          5.0         29.6
16x16          20.7         10.8          7.3         43.3
32x32          29.3          7.9          6.4         51.8
64x64          38.9                                   79.3
128x128        32.4                                   40.7
------------------------------------------------------------
```

Reviewed By: nicolasvasilache, ftynse

Differential Revision: https://reviews.llvm.org/D83012
112 lines
4.2 KiB
C++
112 lines
4.2 KiB
C++
//===- TestVectorToVectorConversion.cpp - Test vector conversion patterns -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <type_traits>
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
namespace {
|
|
|
|
#include "TestVectorTransformPatterns.h.inc"
|
|
|
|
struct TestVectorToVectorConversion
|
|
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
auto *context = &getContext();
|
|
populateWithGenerated(context, &patterns);
|
|
populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
|
populateVectorToVectorTransformationPatterns(patterns, context);
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
struct TestVectorSlicesConversion
|
|
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
populateVectorSlicesLoweringPatterns(patterns, &getContext());
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
struct TestVectorContractionConversion
|
|
: public PassWrapper<TestVectorContractionConversion, FunctionPass> {
|
|
TestVectorContractionConversion() = default;
|
|
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
|
|
}
|
|
|
|
Option<bool> lowerToFlatMatrix{
|
|
*this, "vector-lower-matrix-intrinsics",
|
|
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToFlatTranspose{
|
|
*this, "vector-flat-transpose",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToOuterProduct{
|
|
*this, "vector-outerproduct",
|
|
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToAXPY{*this, "vector-axpy",
|
|
llvm::cl::desc("Lower vector.contract to AXPY"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
|
|
// Test on one pattern in isolation.
|
|
if (lowerToOuterProduct) {
|
|
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
|
|
VectorTransformsOptions options{lowering};
|
|
patterns.insert<ContractionOpToOuterProductOpLowering>(options,
|
|
&getContext());
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
return;
|
|
}
|
|
|
|
// Test on all contract lowering patterns.
|
|
VectorContractLowering contractLowering = VectorContractLowering::Dot;
|
|
if (lowerToFlatMatrix)
|
|
contractLowering = VectorContractLowering::Matmul;
|
|
else if (lowerToAXPY)
|
|
contractLowering = VectorContractLowering::AXPY;
|
|
VectorTransposeLowering transposeLowering =
|
|
VectorTransposeLowering::EltWise;
|
|
if (lowerToFlatTranspose)
|
|
transposeLowering = VectorTransposeLowering::Flat;
|
|
VectorTransformsOptions options{contractLowering, transposeLowering};
|
|
populateVectorContractLoweringPatterns(patterns, &getContext(), options);
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
namespace mlir {
|
|
void registerTestVectorConversions() {
|
|
PassRegistration<TestVectorToVectorConversion> vectorToVectorPass(
|
|
"test-vector-to-vector-conversion",
|
|
"Test conversion patterns between ops in the vector dialect");
|
|
|
|
PassRegistration<TestVectorSlicesConversion> slicesPass(
|
|
"test-vector-slices-conversion",
|
|
"Test conversion patterns that lower slices ops in the vector dialect");
|
|
|
|
PassRegistration<TestVectorContractionConversion> contractionPass(
|
|
"test-vector-contraction-conversion",
|
|
"Test conversion patterns that lower contract ops in the vector dialect");
|
|
}
|
|
} // namespace mlir
|