[mlir][vector] Add vector.to_elements unrolling (#157142)
This PR adds support for unrolling `vector.to_element`'s source operand. It transforms ```mlir %0:8 = vector.to_elements %v : vector<2x2x2xf32> ``` to ```mlir %v0 = vector.extract %v[0] : vector<2x2xf32> from vector<2x2x2xf32> %v1 = vector.extract %v[1] : vector<2x2xf32> from vector<2x2x2xf32> %0:4 = vector.to_elements %v0 : vector<2x2xf32> %1:4 = vector.to_elements %v1 : vector<2x2xf32> // %0:8 = %0:4 - %1:4 ``` This pattern will be applied until there are only 1-D vectors left. --------- Signed-off-by: hanhanW <hanhan0912@gmail.com> Co-authored-by: hanhanW <hanhan0912@gmail.com> Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
This commit is contained in:
parent
ddb2e34334
commit
9d19250610
@ -265,6 +265,17 @@ def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyUnrollToElementsPatternsOp : Op<Transform_Dialect,
|
||||
"apply_patterns.vector.unroll_to_elements",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
let description = [{
|
||||
Indicates that vector to_elements operations should be unrolled
|
||||
along the outermost dimension.
|
||||
}];
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
|
||||
"apply_patterns.vector.lower_scan",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
|
||||
@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
|
||||
void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
/// Populate the pattern set with the following patterns:
|
||||
///
|
||||
/// [UnrollToElements]
|
||||
void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit = 1);
|
||||
|
||||
/// Populate the pattern set with the following patterns:
|
||||
///
|
||||
/// [ContractionOpToMatmulOpLowering]
|
||||
|
||||
@ -255,6 +255,12 @@ using UnrollVectorOpFn =
|
||||
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
|
||||
UnrollVectorOpFn unrollFn);
|
||||
|
||||
/// Generic utility for unrolling values of type vector<NxAxBx...>
|
||||
/// to N values of type vector<AxBx...> using vector.extract. If the input
|
||||
/// is rank-1 or has leading scalable dimension, failure is returned.
|
||||
FailureOr<SmallVector<Value>> unrollVectorValue(TypedValue<VectorType>,
|
||||
RewriterBase &);
|
||||
|
||||
} // namespace vector
|
||||
|
||||
/// Constructs a permutation map of invariant memref indices to vector
|
||||
|
||||
@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
|
||||
populateVectorRankReducingFMAPattern(patterns);
|
||||
populateVectorGatherLoweringPatterns(patterns);
|
||||
populateVectorFromElementsLoweringPatterns(patterns);
|
||||
populateVectorToElementsLoweringPatterns(patterns);
|
||||
if (armI8MM) {
|
||||
if (armNeon)
|
||||
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
|
||||
|
||||
@ -144,6 +144,11 @@ void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
|
||||
vector::populateVectorFromElementsLoweringPatterns(patterns);
|
||||
}
|
||||
|
||||
void transform::ApplyUnrollToElementsPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
vector::populateVectorToElementsLoweringPatterns(patterns);
|
||||
}
|
||||
|
||||
void transform::ApplyLowerScanPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
vector::populateVectorScanLoweringPatterns(patterns);
|
||||
|
||||
@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
|
||||
LowerVectorScan.cpp
|
||||
LowerVectorShapeCast.cpp
|
||||
LowerVectorStep.cpp
|
||||
LowerVectorToElements.cpp
|
||||
LowerVectorToFromElementsToShuffleTree.cpp
|
||||
LowerVectorTransfer.cpp
|
||||
LowerVectorTranspose.cpp
|
||||
|
||||
53
mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
Normal file
53
mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
Normal file
@ -0,0 +1,53 @@
|
||||
//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements target-independent rewrites and utilities to lower the
|
||||
// 'vector.to_elements' operation.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
|
||||
|
||||
#define DEBUG_TYPE "lower-vector-to-elements"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ToElementsOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
TypedValue<VectorType> source = op.getSource();
|
||||
FailureOr<SmallVector<Value>> result =
|
||||
vector::unrollVectorValue(source, rewriter);
|
||||
if (failed(result)) {
|
||||
return failure();
|
||||
}
|
||||
SmallVector<Value> vectors = *result;
|
||||
|
||||
SmallVector<Value> results;
|
||||
for (const Value &vector : vectors) {
|
||||
auto subElements =
|
||||
vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
|
||||
llvm::append_range(results, subElements.getResults());
|
||||
}
|
||||
rewriter.replaceOp(op, results);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::vector::populateVectorToElementsLoweringPatterns(
|
||||
RewritePatternSet &patterns, PatternBenefit benefit) {
|
||||
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
|
||||
}
|
||||
@ -393,6 +393,41 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Takes a 2+ dimensional vector as an input
|
||||
/// returns n vector values produced by n vector.extract operations.
|
||||
/// I.e. calling unrollVectorValue([[%v]], rewriter) such that
|
||||
///
|
||||
/// %v : vector<nxaxb...>
|
||||
///
|
||||
/// will produce the following IR changes
|
||||
///
|
||||
/// %v0 = vector.extract %v[0] : vector<axbx...> from vector<nxaxb...>
|
||||
/// %v1 = vector.extract %v[1] : vector<axbx...> from vector<nxaxb...>
|
||||
/// ...
|
||||
/// %vnminusone = vector.extract %v[n-1] : vector<axbx...> from ...
|
||||
///
|
||||
/// and returns SmallVector<Value> r = {[[%v0]], [[%v1]], ..., [[%vnminusone]]}
|
||||
FailureOr<SmallVector<Value>>
|
||||
vector::unrollVectorValue(TypedValue<VectorType> vector,
|
||||
RewriterBase &rewriter) {
|
||||
SmallVector<Value> subvectors;
|
||||
VectorType ty = cast<VectorType>(vector.getType());
|
||||
Location loc = vector.getLoc();
|
||||
if (ty.getRank() < 2)
|
||||
return rewriter.notifyMatchFailure(loc, "already 1-D");
|
||||
|
||||
// Unrolling doesn't take vscale into account. Pattern is disabled for
|
||||
// vectors with leading scalable dim(s).
|
||||
if (ty.getScalableDims().front())
|
||||
return rewriter.notifyMatchFailure(loc, "cannot unroll scalable dim");
|
||||
|
||||
for (int64_t i = 0, e = ty.getShape().front(); i < e; ++i) {
|
||||
subvectors.push_back(vector::ExtractOp::create(rewriter, loc, vector, i));
|
||||
}
|
||||
|
||||
return subvectors;
|
||||
}
|
||||
|
||||
LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
|
||||
vector::UnrollVectorOpFn unrollFn) {
|
||||
assert(op->getNumResults() == 1 && "expected single result");
|
||||
|
||||
@ -1774,3 +1774,45 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
|
||||
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
|
||||
return %0 : vector<2x1x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.to_elements
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @to_elements_1d(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
|
||||
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
|
||||
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
|
||||
// CHECK: return %[[V0]], %[[V1]]
|
||||
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
|
||||
%0:2 = vector.to_elements %arg0 : vector<2xf32>
|
||||
return %0#0, %0#1 : f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// NOTE: We unroll multi-dimensional to_elements ops with pattern
|
||||
// `UnrollToElements` and then convert the 1-D to_elements ops to llvm.
|
||||
|
||||
// CHECK-LABEL: func @to_elements_2d(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
|
||||
// CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
|
||||
// CHECK: %[[V0:.+]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<2 x vector<2xf32>>
|
||||
// CHECK: %[[V1:.+]] = llvm.extractvalue %[[CAST]][1] : !llvm.array<2 x vector<2xf32>>
|
||||
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: %[[R0:.+]] = llvm.extractelement %[[V0]][%[[C0]] : i64] : vector<2xf32>
|
||||
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: %[[R1:.+]] = llvm.extractelement %[[V0]][%[[C1]] : i64] : vector<2xf32>
|
||||
// CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
|
||||
// CHECK: %[[R2:.+]] = llvm.extractelement %[[V1]][%[[C0]] : i64] : vector<2xf32>
|
||||
// CHECK: %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
|
||||
// CHECK: %[[R3:.+]] = llvm.extractelement %[[V1]][%[[C1]] : i64] : vector<2xf32>
|
||||
// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[R3]]
|
||||
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
|
||||
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
|
||||
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
|
||||
}
|
||||
|
||||
2
mlir/test/Dialect/Vector/lit.local.cfg
Normal file
2
mlir/test/Dialect/Vector/lit.local.cfg
Normal file
@ -0,0 +1,2 @@
|
||||
# Skip the directory with input TD sequences.
|
||||
config.excludes = ["td"]
|
||||
11
mlir/test/Dialect/Vector/td/unroll-elements.mlir
Normal file
11
mlir/test/Dialect/Vector/td/unroll-elements.mlir
Normal file
@ -0,0 +1,11 @@
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @unroll_to_elements(%module_op: !transform.any_op {transform.readonly}) {
|
||||
%f = transform.structured.match ops{["func.func"]} in %module_op
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
transform.apply_patterns to %f {
|
||||
transform.apply_patterns.vector.transfer_permutation_patterns
|
||||
transform.apply_patterns.vector.unroll_to_elements
|
||||
} : !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
26
mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
Normal file
26
mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
Normal file
@ -0,0 +1,26 @@
|
||||
// RUN: mlir-opt %s -test-unroll-vector-to-elements -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -transform-preload-library='transform-library-paths=%p/td/unroll-elements.mlir' \
|
||||
// RUN: -transform-interpreter=entry-point=unroll_to_elements | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @to_elements_1d(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: vector<2xf32>
|
||||
// CHECK: %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
|
||||
// CHECK: return %[[RES]]#0, %[[RES]]#1
|
||||
func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
|
||||
%0:2 = vector.to_elements %arg0 : vector<2xf32>
|
||||
return %0#0, %0#1 : f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @to_elements_2d(
|
||||
// CHECK-SAME: %[[ARG0:.+]]: vector<2x2xf32>
|
||||
// CHECK: %[[VEC0:.+]] = vector.extract %[[ARG0]][0] : vector<2xf32> from vector<2x2xf32>
|
||||
// CHECK: %[[VEC1:.+]] = vector.extract %[[ARG0]][1] : vector<2xf32> from vector<2x2xf32>
|
||||
// CHECK: %[[RES0:.+]]:2 = vector.to_elements %[[VEC0]] : vector<2xf32>
|
||||
// CHECK: %[[RES1:.+]]:2 = vector.to_elements %[[VEC1]] : vector<2xf32>
|
||||
// CHECK: return %[[RES0]]#0, %[[RES0]]#1, %[[RES1]]#0, %[[RES1]]#1
|
||||
func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
|
||||
%0:4 = vector.to_elements %arg0 : vector<2x2xf32>
|
||||
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
|
||||
}
|
||||
@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
|
||||
}
|
||||
};
|
||||
|
||||
struct TestUnrollVectorToElements
|
||||
: public PassWrapper<TestUnrollVectorToElements,
|
||||
OperationPass<func::FuncOp>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorToElements)
|
||||
|
||||
StringRef getArgument() const final {
|
||||
return "test-unroll-vector-to-elements";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test unrolling patterns for to_elements ops";
|
||||
}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<func::FuncDialect, vector::VectorDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorToElementsLoweringPatterns(patterns);
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
struct TestFoldArithExtensionIntoVectorContractPatterns
|
||||
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
|
||||
OperationPass<func::FuncOp>> {
|
||||
@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
|
||||
|
||||
PassRegistration<TestUnrollVectorFromElements>();
|
||||
|
||||
PassRegistration<TestUnrollVectorToElements>();
|
||||
|
||||
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
|
||||
|
||||
PassRegistration<TestVectorEmulateMaskedLoadStore>();
|
||||
|
||||
@ -48,6 +48,8 @@ def non_configurable_patterns():
|
||||
vector.ApplyLowerGatherPatternsOp()
|
||||
# CHECK: transform.apply_patterns.vector.unroll_from_elements
|
||||
vector.ApplyUnrollFromElementsPatternsOp()
|
||||
# CHECK: transform.apply_patterns.vector.unroll_to_elements
|
||||
vector.ApplyUnrollToElementsPatternsOp()
|
||||
# CHECK: transform.apply_patterns.vector.lower_scan
|
||||
vector.ApplyLowerScanPatternsOp()
|
||||
# CHECK: transform.apply_patterns.vector.lower_shape_cast
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user