
Summary: Progressive lowering of vector.transpose into an operation that is closer to an intrinsic, and thus to the hardware ISA. This is currently enabled under the common vector transform testing flag while we prepare to deploy the transformation in the LLVM lowering pipeline. Reviewers: nicolasvasilache, reidtatge, andydavis1, ftynse. Reviewed By: nicolasvasilache, ftynse. Subscribers: mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, llvm-commits. Tags: #llvm, #mlir. Differential Revision: https://reviews.llvm.org/D80772
104 lines
3.9 KiB
C++
104 lines
3.9 KiB
C++
//===- TestVectorToVectorConversion.cpp - Test VectorTransfers lowering ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <type_traits>
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::vector;
|
|
namespace {
|
|
|
|
#include "TestVectorTransformPatterns.h.inc"
|
|
|
|
struct TestVectorToVectorConversion
|
|
: public PassWrapper<TestVectorToVectorConversion, FunctionPass> {
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
auto *context = &getContext();
|
|
populateWithGenerated(context, &patterns);
|
|
populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
|
populateVectorToVectorTransformationPatterns(patterns, context);
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
struct TestVectorSlicesConversion
|
|
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
populateVectorSlicesLoweringPatterns(patterns, &getContext());
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
struct TestVectorContractionConversion
|
|
: public PassWrapper<TestVectorContractionConversion, FunctionPass> {
|
|
TestVectorContractionConversion() = default;
|
|
TestVectorContractionConversion(const TestVectorContractionConversion &pass) {
|
|
}
|
|
|
|
Option<bool> lowerToFlatMatrix{
|
|
*this, "vector-lower-matrix-intrinsics",
|
|
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToFlatTranspose{
|
|
*this, "vector-flat-transpose",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToOuterProduct{
|
|
*this, "vector-outerproduct",
|
|
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnFunction() override {
|
|
OwningRewritePatternList patterns;
|
|
if (lowerToOuterProduct) {
|
|
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
|
|
VectorTransformsOptions options{lowering};
|
|
patterns.insert<ContractionOpToOuterProductOpLowering>(options,
|
|
&getContext());
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
return;
|
|
}
|
|
|
|
VectorContractLowering contractLowering = VectorContractLowering::FMA;
|
|
if (lowerToFlatMatrix)
|
|
contractLowering = VectorContractLowering::Matmul;
|
|
VectorTransposeLowering transposeLowering =
|
|
VectorTransposeLowering::EltWise;
|
|
if (lowerToFlatTranspose)
|
|
transposeLowering = VectorTransposeLowering::Flat;
|
|
VectorTransformsOptions options{contractLowering, transposeLowering};
|
|
populateVectorContractLoweringPatterns(patterns, &getContext(), options);
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
}
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
namespace mlir {
|
|
/// Registers the three vector test passes defined in this file under their
/// command-line names. Each PassRegistration object performs the registration
/// in its constructor, so unnamed temporaries are sufficient.
void registerTestVectorConversions() {
  PassRegistration<TestVectorToVectorConversion>(
      "test-vector-to-vector-conversion",
      "Test conversion patterns between ops in the vector dialect");

  PassRegistration<TestVectorSlicesConversion>(
      "test-vector-slices-conversion",
      "Test conversion patterns that lower slices ops in the vector dialect");

  PassRegistration<TestVectorContractionConversion>(
      "test-vector-contraction-conversion",
      "Test conversion patterns that lower contract ops in the vector dialect");
}
|
|
} // namespace mlir
|