llvm-project/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
Lei Zhang 39c80656fe [mlir][vector] Convert extract_strided_slice to extract & insert chain
This is useful for breaking down extract_strided_slice and potentially
cancel with other extract / insert ops before or after.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D137471
2022-11-09 19:42:07 -05:00

359 lines
14 KiB
C++

//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::vector;
// Helper that picks the proper sequence for inserting.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
Value into, int64_t offset) {
auto vectorType = into.getType().cast<VectorType>();
if (vectorType.getRank() > 1)
return rewriter.create<InsertOp>(loc, from, into, offset);
return rewriter.create<vector::InsertElementOp>(
loc, vectorType, from, into,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
}
// Helper that picks the proper sequence for extracting.
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
int64_t offset) {
auto vectorType = vector.getType().cast<VectorType>();
if (vectorType.getRank() > 1)
return rewriter.create<ExtractOp>(loc, vector, offset);
return rewriter.create<vector::ExtractElementOp>(
loc, vectorType.getElementType(), vector,
rewriter.create<arith::ConstantIndexOp>(loc, offset));
}
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [ConvertSameRankInsertStridedSliceIntoShuffle].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
/// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to
/// insert the k-D source.
/// 2. k-D -> (n-1)-D InsertStridedSlice op
/// 3. InsertOp that is the reverse of 1.
class DecomposeDifferentRankInsertStridedSlice
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.getOffsets().getValue().empty())
return failure();
auto loc = op.getLoc();
int64_t rankDiff = dstType.getRank() - srcType.getRank();
assert(rankDiff >= 0);
if (rankDiff == 0)
return failure();
int64_t rankRest = dstType.getRank() - rankDiff;
// Extract / insert the subvector of matching rank and InsertStridedSlice
// on it.
Value extracted = rewriter.create<ExtractOp>(
loc, op.getDest(),
getI64SubArray(op.getOffsets(), /*dropFront=*/0,
/*dropBack=*/rankRest));
// A different pattern will kick in for InsertStridedSlice with matching
// ranks.
auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
loc, op.getSource(), extracted,
getI64SubArray(op.getOffsets(), /*dropFront=*/rankDiff),
getI64SubArray(op.getStrides(), /*dropFront=*/0));
rewriter.replaceOpWithNewOp<InsertOp>(
op, stridedSliceInnerOp.getResult(), op.getDest(),
getI64SubArray(op.getOffsets(), /*dropFront=*/0,
/*dropBack=*/rankRest));
return success();
}
};
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank. For each outermost index in the slice:
/// begin end stride
/// [offset : offset+size*stride : stride]
/// 1. ExtractOp one (k-1)-D source subvector and one (n-1)-D dest subvector.
/// 2. InsertStridedSlice (k-1)-D into (n-1)-D
/// 3. the destination subvector is inserted back in the proper place
/// 3. InsertOp that is the reverse of 1.
class ConvertSameRankInsertStridedSliceIntoShuffle
: public OpRewritePattern<InsertStridedSliceOp> {
public:
using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
void initialize() {
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
// bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto srcType = op.getSourceVectorType();
auto dstType = op.getDestVectorType();
if (op.getOffsets().getValue().empty())
return failure();
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
assert(dstRank >= srcRank);
if (dstRank != srcRank)
return failure();
if (srcType == dstType) {
rewriter.replaceOp(op, op.getSource());
return success();
}
int64_t offset =
op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
int64_t size = srcType.getShape().front();
int64_t stride =
op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
auto loc = op.getLoc();
Value res = op.getDest();
if (srcRank == 1) {
int nSrc = srcType.getShape().front();
int nDest = dstType.getShape().front();
// 1. Scale source to destType so we can shufflevector them together.
SmallVector<int64_t> offsets(nDest, 0);
for (int64_t i = 0; i < nSrc; ++i)
offsets[i] = i;
Value scaledSource = rewriter.create<ShuffleOp>(loc, op.getSource(),
op.getSource(), offsets);
// 2. Create a mask where we take the value from scaledSource of dest
// depending on the offset.
offsets.clear();
for (int64_t i = 0, e = offset + size * stride; i < nDest; ++i) {
if (i < offset || i >= e || (i - offset) % stride != 0)
offsets.push_back(nDest + i);
else
offsets.push_back((i - offset) / stride);
}
// 3. Replace with a ShuffleOp.
rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.getDest(),
offsets);
return success();
}
// For each slice of the source vector along the most major dimension.
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
// 1. extract the proper subvector (or element) from source
Value extractedSource = extractOne(rewriter, loc, op.getSource(), idx);
if (extractedSource.getType().isa<VectorType>()) {
// 2. If we have a vector, extract the proper subvector from destination
// Otherwise we are at the element level and no need to recurse.
Value extractedDest = extractOne(rewriter, loc, op.getDest(), off);
// 3. Reduce the problem to lowering a new InsertStridedSlice op with
// smaller rank.
extractedSource = rewriter.create<InsertStridedSliceOp>(
loc, extractedSource, extractedDest,
getI64SubArray(op.getOffsets(), /* dropFront=*/1),
getI64SubArray(op.getStrides(), /* dropFront=*/1));
}
// 4. Insert the extractedSource into the res vector.
res = insertOne(rewriter, loc, extractedSource, res, off);
}
rewriter.replaceOp(op, res);
return success();
}
};
/// RewritePattern for ExtractStridedSliceOp where source and destination
/// vectors are 1-D. For such cases, we can lower it to a ShuffleOp.
class Convert1DExtractStridedSliceIntoShuffle
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
int64_t size =
op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
int64_t stride =
op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
assert(dstType.getElementType().isSignlessIntOrIndexOrFloat());
// Single offset can be more efficiently shuffled.
if (op.getOffsets().getValue().size() != 1)
return failure();
SmallVector<int64_t, 4> offsets;
offsets.reserve(size);
for (int64_t off = offset, e = offset + size * stride; off < e;
off += stride)
offsets.push_back(off);
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
op.getVector(),
rewriter.getI64ArrayAttr(offsets));
return success();
}
};
/// For a 1-D ExtractStridedSlice, breaks it down into a chain of Extract ops
/// to extract each element from the source, and then a chain of Insert ops
/// to insert to the target vector.
class Convert1DExtractStridedSliceIntoExtractInsertChain final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
Convert1DExtractStridedSliceIntoExtractInsertChain(
MLIRContext *context,
std::function<bool(ExtractStridedSliceOp)> controlFn,
PatternBenefit benefit)
: OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
if (controlFn && !controlFn(op))
return failure();
// Only handle 1-D cases.
if (op.getOffsets().getValue().size() != 1)
return failure();
int64_t offset =
op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
int64_t size =
op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
int64_t stride =
op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
Location loc = op.getLoc();
SmallVector<Value> elements;
elements.reserve(size);
for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
elements.push_back(rewriter.create<ExtractOp>(loc, op.getVector(), i));
Value result = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(op.getType()));
for (int64_t i = 0; i < size; ++i)
result = rewriter.create<InsertOp>(loc, elements[i], result, i);
rewriter.replaceOp(op, result);
return success();
}
private:
std::function<bool(ExtractStridedSliceOp)> controlFn;
};
/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D.
/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower
/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case.
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
void initialize() {
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
// is bounded as the rank is strictly decreasing.
setHasBoundedRewriteRecursion();
}
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
auto dstType = op.getType();
assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets");
int64_t offset =
op.getOffsets().getValue().front().cast<IntegerAttr>().getInt();
int64_t size =
op.getSizes().getValue().front().cast<IntegerAttr>().getInt();
int64_t stride =
op.getStrides().getValue().front().cast<IntegerAttr>().getInt();
auto loc = op.getLoc();
auto elemType = dstType.getElementType();
assert(elemType.isSignlessIntOrIndexOrFloat());
// Single offset can be more efficiently shuffled. It's handled in
// Convert1DExtractStridedSliceIntoShuffle.
if (op.getOffsets().getValue().size() == 1)
return failure();
// Extract/insert on a lower ranked extract strided slice op.
Value zero = rewriter.create<arith::ConstantOp>(
loc, elemType, rewriter.getZeroAttr(elemType));
Value res = rewriter.create<SplatOp>(loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = extractOne(rewriter, loc, op.getVector(), off);
Value extracted = rewriter.create<ExtractStridedSliceOp>(
loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
getI64SubArray(op.getSizes(), /* dropFront=*/1),
getI64SubArray(op.getStrides(), /* dropFront=*/1));
res = insertOne(rewriter, loc, extracted, res, idx);
}
rewriter.replaceOp(op, res);
return success();
}
};
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
RewritePatternSet &patterns,
std::function<bool(ExtractStridedSliceOp)> controlFn,
PatternBenefit benefit) {
patterns.add<Convert1DExtractStridedSliceIntoExtractInsertChain>(
patterns.getContext(), std::move(controlFn), benefit);
}
/// Populate the given list with patterns that convert from Vector to LLVM.
void vector::populateVectorInsertExtractStridedSliceTransforms(
RewritePatternSet &patterns, PatternBenefit benefit) {
populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns,
benefit);
patterns.add<ConvertSameRankInsertStridedSliceIntoShuffle,
Convert1DExtractStridedSliceIntoShuffle>(patterns.getContext(),
benefit);
}