The revision adds an isOneInteger helper and simplifies the existing code using the two methods. It removes some lambdas, which makes the code cleaner. Downstream users can update their code with the script below. ```bash sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` --------- Signed-off-by: hanhanW <hanhan0912@gmail.com>
775 lines
35 KiB
C++
775 lines
35 KiB
C++
//===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arith/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/LogicalResult.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tensor;
|
|
|
|
namespace {
|
|
/// Fold expand_shape(extract_slice) ops that cancel itself out.
|
|
struct FoldExpandOfRankReducingExtract
|
|
: public OpRewritePattern<ExpandShapeOp> {
|
|
using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
RankedTensorType resultType = expandShapeOp.getResultType();
|
|
auto extractSliceOp =
|
|
expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
|
|
if (!extractSliceOp)
|
|
return failure();
|
|
RankedTensorType srcType = extractSliceOp.getSourceType();
|
|
|
|
// Only cases where the ExpandShapeOp can be folded away entirely are
|
|
// supported. Moreover, only simple cases where the resulting ExtractSliceOp
|
|
// has no rank-reduction anymore are supported at the moment.
|
|
RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
|
|
srcType, extractSliceOp.getStaticOffsets(),
|
|
extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
|
|
if (nonReducingExtractType != resultType)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
|
|
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
|
|
expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
|
|
mixedStrides);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold collapse_shape which only removes static dimensions of size `1`
|
|
/// into extract_slice.
|
|
struct FoldUnPaddingCollapseIntoExtract
|
|
: public OpRewritePattern<tensor::CollapseShapeOp> {
|
|
using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto extractSliceOp =
|
|
collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
|
|
// Collapse cannot be folded away with multiple users of the extract slice
|
|
// and it is not necessarily beneficial to only convert the collapse into
|
|
// another extract slice.
|
|
if (!extractSliceOp || !extractSliceOp->hasOneUse())
|
|
return failure();
|
|
|
|
// Only fold away simple collapse where all removed dimensions have static
|
|
// size `1`.
|
|
SliceVerificationResult res = isRankReducedType(
|
|
collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
|
|
if (res != SliceVerificationResult::Success)
|
|
return rewriter.notifyMatchFailure(collapseShapeOp,
|
|
"expected unpadding collapse");
|
|
|
|
Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
|
|
extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
|
|
extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
|
|
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
|
|
rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
|
|
template <typename OpTy>
|
|
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy insertSliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto collapseShapeOp =
|
|
insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
|
|
if (!collapseShapeOp)
|
|
return failure();
|
|
RankedTensorType srcType = collapseShapeOp.getSrcType();
|
|
|
|
// Only cases where the CollapseShapeOp can be folded away entirely are
|
|
// supported. Moreover, only simple cases where the resulting InsertSliceOp
|
|
// has no rank-reduction anymore are supported at the moment.
|
|
RankedTensorType nonReducingInsertType =
|
|
RankedTensorType::get(insertSliceOp.getStaticSizes(),
|
|
insertSliceOp.getDestType().getElementType());
|
|
if (nonReducingInsertType != srcType)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
|
|
rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
|
|
insertSliceOp.getDest(), mixedOffsets,
|
|
mixedSizes, mixedStrides);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold expand_shape which only adds static dimensions of size `1`
|
|
/// into insert_slice.
|
|
template <typename OpTy>
|
|
struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
|
|
using OpRewritePattern<OpTy>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(OpTy insertSliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto expandShapeOp = insertSliceOp.getSource()
|
|
.template getDefiningOp<tensor::ExpandShapeOp>();
|
|
if (!expandShapeOp)
|
|
return failure();
|
|
|
|
// Only fold away simple expansion where all added dimensions have static
|
|
// size `1`.
|
|
SliceVerificationResult res = isRankReducedType(
|
|
expandShapeOp.getResultType(), expandShapeOp.getSrcType());
|
|
if (res != SliceVerificationResult::Success)
|
|
return rewriter.notifyMatchFailure(insertSliceOp,
|
|
"expected rank increasing expansion");
|
|
|
|
rewriter.modifyOpInPlace(insertSliceOp, [&]() {
|
|
insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc());
|
|
});
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Pattern to bubble up a tensor.expand_shape op through a producer
|
|
/// tensor.collapse_shape op that has non intersecting reassociations.
|
|
struct BubbleUpExpandThroughParallelCollapse
|
|
: public OpRewritePattern<tensor::ExpandShapeOp> {
|
|
using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto collapseOp =
|
|
expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
|
|
if (!collapseOp)
|
|
return failure();
|
|
auto expandReInds = expandOp.getReassociationIndices();
|
|
auto collapseReInds = collapseOp.getReassociationIndices();
|
|
|
|
// Special case where the collapsed tensor to expand is a 0-D tensor,
|
|
// then the reassociation maps will be empty and not produce valid results.
|
|
if (expandReInds.size() == 0) {
|
|
return failure();
|
|
}
|
|
|
|
// Reshapes are parallel to each other (by construction the number of
|
|
// reassociations specified in the collapse and expand are the same), if at
|
|
// any position
|
|
// 1. either the reassociation indices are of the same size, or
|
|
// 2. either the reassociation in the collapse or the expand is of size 1.
|
|
ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
|
|
ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
|
|
for (auto [expandReassociation, collapseReassociation] :
|
|
llvm::zip_equal(expandReInds, collapseReInds)) {
|
|
if (collapseReassociation.size() == expandReassociation.size()) {
|
|
// Even if the reassociations are the same, the collapse/expand should
|
|
// result in the same dimensions. i.e 4x8x2 into 64 should be expanded
|
|
// into 4x8x2 again. In presense of dynamic dimensions one can only
|
|
// verify "equality" when there is only one dynamic dimension present,
|
|
// and all other static dimensions are equal.
|
|
ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
|
|
collapseReassociation.front(), collapseReassociation.size());
|
|
int64_t numCollapsedDynamic =
|
|
llvm::count_if(collapsedStaticShapes,
|
|
[](int64_t d) { return ShapedType::isDynamic(d); });
|
|
ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
|
|
expandReassociation.front(), expandReassociation.size());
|
|
int64_t numExpandedDynamic =
|
|
llvm::count_if(expandedStaticShapes,
|
|
[](int64_t d) { return ShapedType::isDynamic(d); });
|
|
if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
|
|
collapsedStaticShapes != expandedStaticShapes) {
|
|
return failure();
|
|
}
|
|
continue;
|
|
}
|
|
// If the reassociations are not same, one or the other needs to be of
|
|
// size one.
|
|
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
|
|
return failure();
|
|
}
|
|
|
|
// Compute new reassociation indices and expanded/collaped shapes.
|
|
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
|
|
Location loc = expandOp->getLoc();
|
|
SmallVector<OpFoldResult> sourceSizes =
|
|
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
|
|
SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
|
|
SmallVector<OpFoldResult> newExpandSizes;
|
|
|
|
int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
|
|
resultSizeIndex = 0;
|
|
|
|
for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) {
|
|
auto &collapseReassociation = collapseReInds[idx];
|
|
auto &expandReassociation = expandReInds[idx];
|
|
|
|
// Case 1. The reassociations are same in the collapse producer
|
|
// and expand consumer. In the swapped expand, each of the final
|
|
// dimensions are kept as is in the expand and the collapse. So,
|
|
// for every element in the `ReassocationIndices` vector add a new
|
|
// `ReassociationIndices` vector for the swapped expand and collapse
|
|
// (of size 1).
|
|
if (collapseReassociation.size() == expandReassociation.size()) {
|
|
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
|
|
newCollapseReInds.push_back({newCollapseIndex++});
|
|
newExpandReInds.push_back({newExpandIndex++});
|
|
newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
|
|
sourceSizeIndex++;
|
|
}
|
|
continue;
|
|
}
|
|
|
|
// Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and
|
|
// in the expand is of size == 1). In this case, the original dimensions
|
|
// are preserved on expansion and collapsed subsequently.
|
|
if (collapseReassociation.size() != 1) {
|
|
ReassociationIndices newCollapseReassociation;
|
|
for (size_t i = 0; i < collapseReassociation.size(); ++i) {
|
|
newCollapseReassociation.push_back(newCollapseIndex++);
|
|
newExpandReInds.push_back({newExpandIndex++});
|
|
newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]);
|
|
}
|
|
resultSizeIndex++;
|
|
newCollapseReInds.push_back(newCollapseReassociation);
|
|
continue;
|
|
}
|
|
|
|
// Case 3. The `ReassociationIndices` in the expand is of size > 1 (and
|
|
// in the collapse is of size == 1). In this case, the expansion happens
|
|
// first and the expanded dimensions are preserved on collapse.
|
|
ReassociationIndices newExpandReassociation;
|
|
for (size_t i = 0; i < expandReassociation.size(); ++i) {
|
|
newExpandReassociation.push_back(newExpandIndex++);
|
|
newCollapseReInds.push_back({newCollapseIndex++});
|
|
newExpandSizes.push_back(resultSizes[resultSizeIndex++]);
|
|
}
|
|
newExpandReInds.push_back(newExpandReassociation);
|
|
sourceSizeIndex++;
|
|
}
|
|
|
|
// Swap reshape order.
|
|
SmallVector<Value> dynamicSizes;
|
|
SmallVector<int64_t> staticSizes;
|
|
dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
|
|
auto expandResultType = expandOp.getResultType().clone(staticSizes);
|
|
Value newCollapseSrc = collapseOp.getSrc();
|
|
// If the number of reassociation indices in the new `expand_shape` op
|
|
// matches the number of dimensions of the result, then the expand_shape
|
|
// is a no-op.
|
|
if (newExpandReInds.size() != newExpandSizes.size()) {
|
|
newCollapseSrc = rewriter.create<tensor::ExpandShapeOp>(
|
|
loc, expandResultType, newCollapseSrc, newExpandReInds,
|
|
newExpandSizes);
|
|
}
|
|
|
|
// If the number of reassociation indices in the new `collapse_shape` op
|
|
// matches the number of dimensions of the source, then the collapse_shape
|
|
// is a no-op.
|
|
Value replacement = newCollapseSrc;
|
|
if (newCollapseReInds.size() != newExpandSizes.size()) {
|
|
replacement = rewriter.create<tensor::CollapseShapeOp>(
|
|
loc, newCollapseSrc, newCollapseReInds);
|
|
}
|
|
rewriter.replaceOp(expandOp, replacement);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts `tensor.extract_slice(tensor.expand_shape)` to
|
|
/// `tensor.expand_shape(tensor.extract_slice)`.
|
|
///
|
|
/// For this transformation to be possible, the slice must be fully contiguous
|
|
/// within each reassociation group of the expand_shape. A slice is defined as
|
|
/// fully contiguous within a reassociation group if after flattening the
|
|
/// reassociation group to a single 1D range, then the slice taken out of the
|
|
/// group could be defined as a single contiguous subrange within that range.
|
|
///
|
|
/// Rank reducing slices are not supported.
|
|
///
|
|
/// Example:
|
|
/// The transformation is possible because each reassociation group has a
|
|
/// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]).
|
|
/// ```
|
|
/// BEFORE:
|
|
/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
|
|
/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
|
|
/// %slice = tensor.extract_slice %reshape ...
|
|
/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
|
|
///
|
|
/// AFTER:
|
|
/// %slice = tensor.extract_slice %in ...
|
|
/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
|
|
/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
|
|
/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
|
|
/// ```
|
|
///
|
|
/// Note - this pattern could be extended to be a swap pattern between
|
|
/// `tensor.expand_shape` and `tensor.extract_slice`, but is currently
|
|
/// implemented only as a bubble up pattern for `tensor.extract_slice`.
|
|
struct BubbleUpExpandShapeThroughExtractSlice
|
|
: public OpRewritePattern<tensor::ExtractSliceOp> {
|
|
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto expandShapeOp =
|
|
sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
|
|
|
|
if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp,
|
|
rewriter)
|
|
.failed())
|
|
return failure();
|
|
|
|
// The tensor.extract_slice before applying the pattern works on the result
|
|
// of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp)
|
|
// referring to the state before applying the pattern are named with the
|
|
// prefix "expanded", and ones referring to the state after applying the
|
|
// pattern are named with the prefix "collapsed".
|
|
SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes();
|
|
SmallVector<OpFoldResult> expandedShape =
|
|
getMixedValues(expandShapeOp.getStaticOutputShape(),
|
|
expandShapeOp.getOutputShape(), rewriter);
|
|
|
|
// Helper variables and function for accumulating the size values.
|
|
Location loc = expandShapeOp->getLoc();
|
|
AffineExpr d0, d1, d2;
|
|
bindDims(rewriter.getContext(), d0, d1, d2);
|
|
// Multiply two integers.
|
|
auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
|
|
auto mulMap = AffineMap::get(2, 0, {d0 * d1});
|
|
return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
|
|
{v1, v2});
|
|
};
|
|
|
|
// Compute new offsets, sizes, and strides for tensor.extract_slice.
|
|
// The new tensor.extract_slice will work on a tensor that has has a rank of
|
|
// ReassociationIndices.size(). In the loop a single offset, size, and
|
|
// stride value is computed per reassociation group.
|
|
SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes,
|
|
collapsedStrides;
|
|
for (const ReassociationIndices &indices :
|
|
expandShapeOp.getReassociationIndices()) {
|
|
// collapsedSize will hold the size of the single dim that represents the
|
|
// reassociation group in the non expanded tensor.
|
|
OpFoldResult collapsedSize = rewriter.getIndexAttr(1);
|
|
// The reassocGroupSizes and reassocGroupOffsets are used to create an
|
|
// affine.linearize_index op to linearize the single offset value required
|
|
// for this reassociation group.
|
|
SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets;
|
|
|
|
for (long expandedDim : indices) {
|
|
// reassocGroupSizes and reassocGroupOffsets can be obtained directly
|
|
// from the expanded state, but the collapsed size requires calculation
|
|
// as it did not previously exist.
|
|
reassocGroupSizes.push_back(expandedShape[expandedDim]);
|
|
reassocGroupOffsets.push_back(expandedOffsets[expandedDim]);
|
|
collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]);
|
|
}
|
|
|
|
SmallVector<Value> offsetVals =
|
|
llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) {
|
|
return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
|
|
});
|
|
OpFoldResult collapsedOffset =
|
|
rewriter
|
|
.create<affine::AffineLinearizeIndexOp>(loc, offsetVals,
|
|
reassocGroupSizes,
|
|
/*disjoint=*/true)
|
|
.getResult();
|
|
collapsedOffsets.push_back(collapsedOffset);
|
|
collapsedSizes.push_back(collapsedSize);
|
|
|
|
// Only unit stride is supported.
|
|
collapsedStrides.push_back(rewriter.getIndexAttr(1));
|
|
}
|
|
|
|
// The shape of the result can be obtained from the sizes passed in.
|
|
SmallVector<Value> dynDims;
|
|
SmallVector<int64_t> shape;
|
|
dispatchIndexOpFoldResults(expandedSizes, dynDims, shape);
|
|
RankedTensorType resultType = RankedTensorType::get(
|
|
shape, expandShapeOp.getResultType().getElementType());
|
|
|
|
// Create a new ExtractSliceOp and ExpandShapeOp.
|
|
Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
|
loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes,
|
|
collapsedStrides);
|
|
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
|
sliceOp, resultType, newSliceOp,
|
|
expandShapeOp.getReassociationIndices(), expandedSizes);
|
|
return success();
|
|
}
|
|
|
|
// Helper function to check if all the required conditions for the
|
|
// tensor.extract_slice to be bubbled up through the tensor.expand_shape are
|
|
// met.
|
|
LogicalResult
|
|
checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp,
|
|
tensor::ExpandShapeOp expandShapeOp,
|
|
PatternRewriter &rewriter) const {
|
|
|
|
if (!expandShapeOp) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "tensor.extract_slice source not produced by expand_shape");
|
|
}
|
|
|
|
if (!sliceOp.hasUnitStride()) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
|
|
"be supported in this transformation.");
|
|
}
|
|
|
|
SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
|
|
|
|
if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
|
|
sizes.size()) {
|
|
return rewriter.notifyMatchFailure(sliceOp,
|
|
"unimplemented: rank reducing slice");
|
|
}
|
|
|
|
SmallVector<OpFoldResult> outputShape =
|
|
getMixedValues(expandShapeOp.getStaticOutputShape(),
|
|
expandShapeOp.getOutputShape(), rewriter);
|
|
|
|
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
|
|
isZeroOffsetAndFullSize =
|
|
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
|
|
if (!isZeroInteger(offset))
|
|
return false;
|
|
FailureOr<bool> maybeEqual =
|
|
ValueBoundsConstraintSet::areEqual(sliceSize, size);
|
|
return llvm::succeeded(maybeEqual) && maybeEqual.value();
|
|
};
|
|
|
|
// Check that the slice is contiguous within each reassociation group.
|
|
// The slice is contiguous only if after the first dimension where a non
|
|
// unit slice is taken, the slice size on all subsequent dimensions of the
|
|
// group is equal to the entire size of the dimension.
|
|
// Examples of contiguous slices:
|
|
// full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10]
|
|
// full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10]
|
|
// Examples of non contiguous slices:
|
|
// full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5]
|
|
// full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5]
|
|
for (const ReassociationIndices &indices :
|
|
expandShapeOp.getReassociationIndices()) {
|
|
int64_t i = 0;
|
|
int64_t e = indices.size();
|
|
// Find the first expanded dim after the first dim with non-unit extracted
|
|
// size.
|
|
for (; i < e; ++i) {
|
|
if (!isOneInteger(sizes[indices[i]])) {
|
|
// +1 to skip the first non-unit size dim.
|
|
i++;
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Verify that all subsequent dimensions extract the full size of the
|
|
// source tensor.
|
|
for (; i < e; ++i) {
|
|
int64_t expandedDim = indices[i];
|
|
if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
|
|
outputShape[expandedDim])) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "Not a contiguous slice of the expanded tensor.");
|
|
}
|
|
}
|
|
}
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Converts `tensor.extract_slice(tensor.collapse_shape)` to
|
|
/// `tensor.collapse_shape(tensor.extract_slice)`.
|
|
///
|
|
/// For this transformation to be possible - after bubbling up, the extraction
|
|
/// of the contiguous slice must be representable as a single slice obtained via
|
|
/// tensor.extract_slice within each reassociation group of the src.
|
|
///
|
|
/// In case the size and offset extracted are static then this is possible if
|
|
/// the following conditions are met within each reassociation group:
|
|
/// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the
|
|
/// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the
|
|
/// shape of a desired slice. A slice of shape S can be extracted as a
|
|
/// contiguous span of elements if and only if there exists an index k in {0, 1,
|
|
/// ..., n} such that:
|
|
/// S_i = 1 for all i < k (that is, all leading dimensions are singleton),
|
|
/// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly
|
|
/// one dimension),
|
|
/// S_i = A_i for all i > k (that is, all trailing dimensions are preserved
|
|
/// in full).
|
|
/// In other words, the slice shape S must be of the form:
|
|
/// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ]
|
|
///
|
|
/// In case the size and/or offset extracted are dynamic then this is possible
|
|
/// only if there is single dimension in the reassociation group that has a size
|
|
/// not equal to 1.
|
|
/// In other words, the tensor shape must be of the form:
|
|
/// [ 1, 1, ..., 1, A, 1, ...,1 ]
|
|
/// Note - it might be possible to enable this pattern for more cases when the
|
|
/// size/offset are dynamic via performing an analysis of the possible values
|
|
/// that could be given to the size/offset.
|
|
///
|
|
/// Example:
|
|
/// The transformation is possible because each reassociation group can be
|
|
/// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?],
|
|
/// [20->10]).
|
|
/// ```
|
|
/// BEFORE:
|
|
/// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ...
|
|
/// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32>
|
|
/// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1]
|
|
/// tensor<128x7x20xf32> to tensor<32x?x10xf32>
|
|
///
|
|
/// AFTER:
|
|
/// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10]
|
|
// [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32>
|
|
/// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ...
|
|
/// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32>
|
|
/// ```
|
|
///
|
|
/// Negative example:
|
|
/// The transformation is not possible because we cannot use a single slice to
|
|
/// represent the reassociation group [2x3x10->???]. If we would want the
|
|
/// collapse to be after the extraction, we would need to extract multiple
|
|
/// slices and concat them together.
|
|
/// ```
|
|
/// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into
|
|
/// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] :
|
|
/// tensor<60xf32> to tensor<15xf32>
|
|
/// ```
|
|
/// If we would want the collapse to be after the extraction, a possible
|
|
/// alternate transformation could be to extract multiple slices and concat them
|
|
/// together:
|
|
/// ```
|
|
/// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] :
|
|
/// tensor<2x3x10xf32> to tensor <1x1x10xf32>
|
|
/// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] :
|
|
/// tensor<2x3x10xf32> to tensor <1x1x5xf32>
|
|
/// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} :
|
|
/// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32>
|
|
/// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32>
|
|
/// to tensor<15xf32>
|
|
/// ```
|
|
/// But this is not the intended purpose of the transformation.
|
|
struct BubbleUpCollapseShapeThroughExtractSlice
|
|
: public OpRewritePattern<tensor::ExtractSliceOp> {
|
|
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto collapseShapeOp =
|
|
sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
|
|
if (!collapseShapeOp) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp,
|
|
"tensor.extract_slice source not produced by tensor.collapse_shape");
|
|
}
|
|
|
|
if (!sliceOp.hasUnitStride()) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "unsupported: non-unit stride. Only contiguous slices can "
|
|
"be supported in this transformation.");
|
|
}
|
|
|
|
// The tensor.extract_slice before applying the pattern works on the result
|
|
// of the tensor.collapse_shape, so variables (i.e. inputs for
|
|
// ExtractSliceOp) referring to the state before applying the pattern are
|
|
// named with the prefix "collapsed", and ones referring to the state after
|
|
// applying the pattern are named with the prefix "expanded".
|
|
SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets();
|
|
SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes();
|
|
|
|
if (static_cast<size_t>(sliceOp.getResultType().getRank()) !=
|
|
collapsedSizes.size()) {
|
|
return rewriter.notifyMatchFailure(sliceOp,
|
|
"unimplemented: rank reducing slice");
|
|
}
|
|
|
|
ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape();
|
|
SmallVector<ReassociationIndices, 4> reassociationIndices =
|
|
collapseShapeOp.getReassociationIndices();
|
|
|
|
// Compute new offsets, sizes, and strides for tensor.extract_slice.
|
|
// The new tensor.extract_slice will work on a tensor that has has a rank
|
|
// equal to the rank of the src of the collapse_shape. In each iteration of
|
|
// the loop, the offsets and sizes will be computed per reassociation group.
|
|
SmallVector<OpFoldResult> expandedOffsets, expandedSizes;
|
|
SmallVector<OpFoldResult> expandedStrides(srcShape.size(),
|
|
rewriter.getIndexAttr(1));
|
|
|
|
for (auto [collapsedSize, collapsedOffset, reassocIndices] :
|
|
llvm::zip_equal(collapsedSizes, collapsedOffsets,
|
|
collapseShapeOp.getReassociationIndices())) {
|
|
// CASE #1 - size and/or offset are dynamic.
|
|
// In this case, the slice can be represented as a contiguous slice only
|
|
// if there is a single dimension in the reassociation group that has a
|
|
// size not equal to 1.
|
|
if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) {
|
|
int nonUnitSizeCount = 0;
|
|
for (int64_t expandedShapeIdx : reassocIndices) {
|
|
if (srcShape[expandedShapeIdx] != 1) {
|
|
nonUnitSizeCount++;
|
|
expandedSizes.push_back(collapsedSize);
|
|
expandedOffsets.push_back(collapsedOffset);
|
|
continue;
|
|
}
|
|
|
|
expandedSizes.push_back(rewriter.getIndexAttr(1));
|
|
expandedOffsets.push_back(rewriter.getIndexAttr(0));
|
|
}
|
|
|
|
if (nonUnitSizeCount != 1) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp,
|
|
"unsupported: slice cannot be verified to be contiguous");
|
|
}
|
|
continue;
|
|
}
|
|
|
|
// CASE #2 = size and offset are static.
|
|
// Verify that the slice can be represented as a contiguous slice of the
|
|
// src of the collapse_shape.
|
|
// Checking this is done on order of most internal dimensions first,
|
|
// so traversal is done in reverse order of the reassociation group.
|
|
// If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2,
|
|
// ...,An] then we first find the size and offset for n...k+1 then for k
|
|
// and then for k-1...0.
|
|
|
|
// currentCollapsedsize and currentCollapsedOffset are initialized with
|
|
// the original collapsed size and offset and divided by the expanded
|
|
// shape size in each dimension as we go along the reassociation group.
|
|
// In essence we are spreading the original collapsed size and offset over
|
|
// the various expanded slice dimensions.
|
|
// The variables are used both to check the validity of the slice and to
|
|
// compute the expanded sizes and offsets.
|
|
int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value();
|
|
int64_t currentCollapsedOffset =
|
|
getConstantIntValue(collapsedOffset).value();
|
|
|
|
SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets;
|
|
|
|
ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(),
|
|
reassocIndices.rend());
|
|
int64_t idx = 0;
|
|
int64_t reassocGroupSize = reassocIndices.size();
|
|
|
|
// First handle the trailing dimensions where the slice size should be
|
|
// equal to the tensor shape and the offset should be 0 (n...k+1).
|
|
for (; idx < reassocGroupSize; ++idx) {
|
|
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
|
|
|
|
if (currentCollapsedsize < expandedShapeSize)
|
|
break;
|
|
|
|
// We need to make sure that the slice size can be set to the shape size
|
|
// and the offset to 0.
|
|
if ((currentCollapsedsize % expandedShapeSize) != 0 ||
|
|
(currentCollapsedOffset % expandedShapeSize) != 0) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "unsupported: cannot be extracted as a contiguous slice "
|
|
"of the src of the collapse_shape");
|
|
}
|
|
|
|
groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize));
|
|
groupExpandedOffsets.push_back(rewriter.getIndexAttr(0));
|
|
|
|
currentCollapsedsize /= expandedShapeSize;
|
|
currentCollapsedOffset /= expandedShapeSize;
|
|
}
|
|
|
|
// Now handle the first dim where slicing occurs on (k).
|
|
if (idx < reassocGroupSize) {
|
|
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
|
|
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
|
|
// We need to make sure that the slice size in this dim + offset will
|
|
// not exceed the shape size.
|
|
if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) {
|
|
return rewriter.notifyMatchFailure(
|
|
sliceOp, "unsupported: slice cannot be extracted as a contiguous "
|
|
"slice of the src of the collapse_shape");
|
|
}
|
|
|
|
groupExpandedSizes.push_back(
|
|
rewriter.getIndexAttr(currentCollapsedsize));
|
|
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
|
|
|
|
currentCollapsedOffset /= expandedShapeSize;
|
|
}
|
|
|
|
// Now handle the leading dimensions where the slice size is equal to 1
|
|
// (k-1...0).
|
|
// The size for these dimensions must be 1 because of how we constructed
|
|
// the slice size of the expanded shape. We spread the original collapsed
|
|
// size over the expanded shape sizes until we reached dimension k where
|
|
// the remaining size was smaller than the expanded shape size, and spread
|
|
// the remaining size on it. So, now we are left with only 1s.
|
|
for (idx++; idx < reassocGroupSize; ++idx) {
|
|
int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]];
|
|
int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize;
|
|
groupExpandedSizes.push_back(rewriter.getIndexAttr(1));
|
|
groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim));
|
|
currentCollapsedOffset /= expandedShapeSize;
|
|
}
|
|
|
|
expandedSizes.append(groupExpandedSizes.rbegin(),
|
|
groupExpandedSizes.rend());
|
|
expandedOffsets.append(groupExpandedOffsets.rbegin(),
|
|
groupExpandedOffsets.rend());
|
|
}
|
|
|
|
Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
|
|
collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), expandedOffsets,
|
|
expandedSizes, expandedStrides);
|
|
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
|
|
sliceOp, sliceOp.getResultType(), newSliceOp,
|
|
collapseShapeOp.getReassociationIndices());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers the patterns that fold rank-reducing slices with the reshapes
/// that undo (or redo) their rank reduction.
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<FoldExpandOfRankReducingExtract,
               FoldUnPaddingCollapseIntoExtract>(ctx);
  patterns.add<FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
      ctx);
  patterns.add<FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
               FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(ctx);
}
|
|
|
|
/// Registers the pattern that moves an expand_shape above a parallel
/// collapse_shape.
void mlir::tensor::populateBubbleUpExpandShapePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleUpExpandThroughParallelCollapse>(context);
}
|
|
|
|
/// Registers the patterns that move an extract_slice above a producing
/// expand_shape or collapse_shape.
void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<BubbleUpExpandShapeThroughExtractSlice,
               BubbleUpCollapseShapeThroughExtractSlice>(context);
}
|