llvm-project/mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
Quentin Colombet d0aeb74e88 [mlir][MemRef] Simplify extract_strided_metadata(expand_shape)
Add a pattern to the pass that simplifies
extract_strided_metadata(other_op(memref)).

The new pattern gets rid of the expand_shape operation while
materializing its effects on the sizes, and the strides of
the base object.

In other words, this simplification replaces:
```
baseBuffer, offset, sizes, strides =
             extract_strided_metadata(expand_shape(memref))
```

With

```
baseBuffer, offset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
sizes#reassIdx =
    baseSizes#reassDim / product(expandShapeSizes#j,
                                 for j in group excluding
                                   reassIdx)
strides#reassIdx =
    baseStrides#reassDim * product(expandShapeSizes#j,
                                   for j in
                                     reassIdx+1..
                                       reassIdx+group.size-1)
```

Where `reassIdx` is a reassociation index for the group at
`reassDim` and `expandShapeSizes#j` is either:
- The constant size at dimension j, derived directly from the
  result type of the expand_shape op, or
- An affine expression: baseSizes#reassDim / product of all
  constant sizes in expandShapeSizes.

Note: baseBuffer and offset are unaffected by the expand_shape
operation.

Differential Revision: https://reviews.llvm.org/D133625
2022-09-22 19:07:09 +00:00

464 lines
19 KiB
C++

//===- SimplifyExtractStridedMetadata.cpp - Simplify this operation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass simplifies extract_strided_metadata(other_op(memref)) to
// extract_strided_metadata(memref) when it is possible to express the effect
// of other_op using affine apply on the results of
// extract_strided_metadata(memref).
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_SIMPLIFYEXTRACTSTRIDEDMETADATA
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
namespace {
/// Replace `baseBuffer, offset, sizes, strides =
/// extract_strided_metadata(subview(memref, subOffset,
/// subSizes, subStrides))`
/// With
///
/// \verbatim
/// baseBuffer, baseOffset, baseSizes, baseStrides =
/// extract_strided_metadata(memref)
/// strides#i = baseStrides#i * subStrides#i
/// offset = baseOffset + sum(subOffset#i * strides#i)
/// sizes = subSizes
/// \endverbatim
///
/// In other words, get rid of the subview in that expression and canonicalize
/// on its effects on the offset, the sizes, and the strides using affine.apply.
///
/// The pattern does not match when the source of the
/// extract_strided_metadata is not produced by a memref.subview.
/// Dimensions dropped by the subview (rank-reducing case) are left out of the
/// replacement sizes and strides.
struct ExtractStridedMetadataOpSubviewFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
/// Rewrite \p op in terms of an extract_strided_metadata applied directly to
/// the source of the subview feeding it. Returns failure() when \p op is not
/// fed by a subview, success() once \p op has been replaced.
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
auto subview = op.getSource().getDefiningOp<memref::SubViewOp>();
if (!subview)
return failure();
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(subview(memref)).
Location origLoc = op.getLoc();
IndexType indexType = rewriter.getIndexType();
Value source = subview.getSource();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
// The new op yields one index-typed size and one index-typed stride per
// dimension of the subview's source.
SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(
origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
sizeStrideTypes, source);
// Query what the source type says statically about strides and offset.
// Failure is only checked through the assert below.
SmallVector<int64_t> sourceStrides;
int64_t sourceOffset;
bool hasKnownStridesAndOffset =
succeeded(getStridesAndOffset(sourceType, sourceStrides, sourceOffset));
(void)hasKnownStridesAndOffset;
assert(hasKnownStridesAndOffset &&
"getStridesAndOffset must work on valid subviews");
// Compute the new strides and offset from the base strides and offset:
// newStride#i = baseStride#i * subStride#i
// offset = baseOffset + sum(subOffsets#i * newStrides#i)
SmallVector<OpFoldResult> strides;
SmallVector<OpFoldResult> subStrides = subview.getMixedStrides();
auto origStrides = newExtractStridedMetadata.getStrides();
// Hold the affine symbols and values for the computation of the offset.
// Slot 0 is the base offset; dimension i then owns the three consecutive
// slots starting at 1 + 3 * i: subOffset, subStride, and original stride.
SmallVector<OpFoldResult> values(3 * sourceRank + 1);
SmallVector<AffineExpr> symbols(3 * sourceRank + 1);
detail::bindSymbolsList(rewriter.getContext(), symbols);
AffineExpr expr = symbols.front();
// Use the offset from the type when it is static, otherwise the offset
// result of the new extract_strided_metadata.
values[0] = ShapedType::isDynamicStrideOrOffset(sourceOffset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
: rewriter.getIndexAttr(sourceOffset);
SmallVector<OpFoldResult> subOffsets = subview.getMixedOffsets();
// s0 and s1 are the two operands of the per-dimension stride product below.
AffineExpr s0 = rewriter.getAffineSymbolExpr(0);
AffineExpr s1 = rewriter.getAffineSymbolExpr(1);
for (unsigned i = 0; i < sourceRank; ++i) {
// Compute the stride.
// Same static-vs-dynamic choice as for the offset: take the stride from the
// type when it is static, otherwise the matching result of the new op.
OpFoldResult origStride =
ShapedType::isDynamicStrideOrOffset(sourceStrides[i])
? origStrides[i]
: OpFoldResult(rewriter.getIndexAttr(sourceStrides[i]));
strides.push_back(makeComposedFoldedAffineApply(
rewriter, origLoc, s0 * s1, {subStrides[i], origStride}));
// Build up the computation of the offset.
// Adds subOffset#i * subStride#i * origStride#i to the running expression.
unsigned baseIdxForDim = 1 + 3 * i;
unsigned subOffsetForDim = baseIdxForDim;
unsigned subStrideForDim = baseIdxForDim + 1;
unsigned origStrideForDim = baseIdxForDim + 2;
expr = expr + symbols[subOffsetForDim] * symbols[subStrideForDim] *
symbols[origStrideForDim];
values[subOffsetForDim] = subOffsets[i];
values[subStrideForDim] = subStrides[i];
values[origStrideForDim] = origStride;
}
// Compute the offset.
OpFoldResult finalOffset =
makeComposedFoldedAffineApply(rewriter, origLoc, expr, values);
SmallVector<Value> results;
// The final result is <baseBuffer, offset, sizes, strides>.
// Thus we need 1 + 1 + subview.getRank() + subview.getRank(), to hold all
// the values.
auto subType = subview.getType().cast<MemRefType>();
unsigned subRank = subType.getRank();
// Properly size the array so that we can do random insertions
// at the right indices.
// We do that to populate the non-dropped sizes and strides in one go.
results.resize_for_overwrite(subRank * 2 + 2);
results[0] = newExtractStridedMetadata.getBaseBuffer();
results[1] =
getValueOrCreateConstantIndexOp(rewriter, origLoc, finalOffset);
// The sizes of the final type are defined directly by the input sizes of
// the subview.
// Moreover subviews can drop some dimensions, some strides and sizes may
// not end up in the final <base, offset, sizes, strides> value that we are
// replacing.
// Do the filtering here.
SmallVector<OpFoldResult> subSizes = subview.getMixedSizes();
// Sizes occupy results[2 .. 2 + subRank), strides the subRank slots after.
const unsigned sizeStartIdx = 2;
const unsigned strideStartIdx = sizeStartIdx + subRank;
// Number of kept (non-dropped) dimensions written so far; also the position
// of the next kept dimension within the size and stride ranges.
unsigned insertedDims = 0;
llvm::SmallBitVector droppedDims = subview.getDroppedDims();
for (unsigned i = 0; i < sourceRank; ++i) {
if (droppedDims.test(i))
continue;
results[sizeStartIdx + insertedDims] =
getValueOrCreateConstantIndexOp(rewriter, origLoc, subSizes[i]);
results[strideStartIdx + insertedDims] =
getValueOrCreateConstantIndexOp(rewriter, origLoc, strides[i]);
++insertedDims;
}
assert(insertedDims == subRank &&
"Should have populated all the values at this point");
rewriter.replaceOp(op, results);
return success();
}
};
/// Return the sizes of the result dimensions of \p expandShape that belong to
/// the \p groupId-th reassociation group.
///
/// Statically known sizes are read from the result type of \p expandShape and
/// returned as index attributes. If the group contains a dynamic dimension
/// (at most one is allowed), its size is built with \p builder as:
///   origSizes[groupId] floordiv product(static sizes of the group).
///
/// \p origSizes holds the sizes of the source shape; only the entry at
/// \p groupId is read, and only when the group has a dynamic dimension.
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                 ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
  SmallVector<int64_t, 2> group =
      expandShape.getReassociationIndices()[groupId];
  assert(!group.empty() &&
         "Reassociation group should have at least one dimension");
  MemRefType resultType = expandShape.getResultType();

  const unsigned numDimsInGroup = group.size();
  SmallVector<OpFoldResult> sizes(numDimsInGroup);
  // Position, within the group, of the dynamic dimension (if any).
  Optional<unsigned> dynamicPos;
  // Running product of every static size seen in the group.
  uint64_t staticProduct = 1;

  // First pass: record the static sizes and locate the dynamic one.
  for (unsigned pos = 0; pos != numDimsInGroup; ++pos) {
    uint64_t extent = resultType.getDimSize(group[pos]);
    if (!ShapedType::isDynamic(extent)) {
      staticProduct *= extent;
      sizes[pos] = builder.getIndexAttr(extent);
    } else {
      assert(!dynamicPos && "There must be at most one dynamic size per group");
      dynamicPos = pos;
    }
  }

  // Fully static group: nothing left to compute.
  if (!dynamicPos)
    return sizes;

  // The dynamic size is whatever remains of the original size once all the
  // static sizes have been divided out:
  //   expandSize = origSize / productOfAllStaticSizes.
  AffineExpr origSizeSym = builder.getAffineSymbolExpr(0);
  sizes[*dynamicPos] = makeComposedFoldedAffineApply(
      builder, expandShape.getLoc(), origSizeSym.floorDiv(staticProduct),
      origSizes[groupId]);
  return sizes;
}
/// Compute the expanded strides of the given \p expandShape for the
/// \p groupId-th reassociation group.
///
/// \p origStrides and \p origSizes hold respectively the strides and sizes
/// of the source shape as values. They are used to compute the strides in
/// cases of dynamic shapes and/or dynamic stride for this reassociation group:
/// - \p origStrides[groupId] is only read when the source type does not carry
///   a static stride for that dimension;
/// - \p origSizes[groupId] is only read when the group contains a dynamic
///   dimension.
///
/// Within a group, the stride of a dimension is the product of the sizes of
/// the dimensions that follow it in the group, scaled by the stride of the
/// source dimension the group was expanded from. For the dimensions located
/// before the (unique) dynamic dimension, the dynamic size is expressed as
/// origSize floordiv product(static sizes of the group).
///
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
///
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
/// this is not possible because this function uses the Affine dialect and the
/// MemRef dialect cannot depend on the Affine dialect.
static SmallVector<OpFoldResult>
getExpandedStrides(memref::ExpandShapeOp expandShape, OpBuilder &builder,
                   ArrayRef<OpFoldResult> origSizes,
                   ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
  SmallVector<int64_t, 2> reassocGroup =
      expandShape.getReassociationIndices()[groupId];
  assert(!reassocGroup.empty() &&
         "Reassociation group should have at least one dimension");

  unsigned groupSize = reassocGroup.size();
  MemRefType expandShapeType = expandShape.getResultType();

  Optional<int64_t> dynSizeIdx;

  // Fill up the expanded strides, with the information we can deduce from the
  // resulting shape.
  // Walk the group from its last dimension to its first, accumulating the
  // product of the static sizes. The dynamic size, if any, is skipped here
  // and patched in below.
  uint64_t currentStride = 1;
  SmallVector<OpFoldResult> expandedStrides(groupSize);
  for (int i = groupSize - 1; i >= 0; --i) {
    expandedStrides[i] = builder.getIndexAttr(currentStride);
    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
    if (ShapedType::isDynamic(dimSize)) {
      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
      dynSizeIdx = i;
      continue;
    }

    currentStride *= dimSize;
  }

  // Collect the statically known information about the original stride.
  Value source = expandShape.getSrc();
  auto sourceType = source.getType().cast<MemRefType>();

  SmallVector<int64_t> strides;
  int64_t offset;
  bool hasKnownStridesAndOffset =
      succeeded(getStridesAndOffset(sourceType, strides, offset));
  (void)hasKnownStridesAndOffset;
  assert(hasKnownStridesAndOffset &&
         "getStridesAndOffset must work on valid expand_shape");

  // Prefer the stride recorded in the source type; fall back to the value
  // supplied by the caller when that stride is dynamic.
  OpFoldResult origStride =
      ShapedType::isDynamicStrideOrOffset(strides[groupId])
          ? origStrides[groupId]
          : builder.getIndexAttr(strides[groupId]);

  // Apply the original stride to all the strides.
  int64_t doneStrideIdx = 0;
  // If we saw a dynamic dimension, we need to fix-up all the strides up to
  // that dimension with the dynamic size.
  if (dynSizeIdx) {
    // After the loop above, currentStride is the product of every static size
    // in the group.
    int64_t productOfAllStaticSizes = currentStride;
    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
           "We shouldn't be able to change dynamicity");
    OpFoldResult origSize = origSizes[groupId];

    AffineExpr s0 = builder.getAffineSymbolExpr(0);
    AffineExpr s1 = builder.getAffineSymbolExpr(1);
    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
      int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
                                       .get<Attribute>()
                                       .cast<IntegerAttr>()
                                       .getInt();
      // stride = ((origSize * staticStride) floordiv productOfStatics)
      //          * origStride
      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
          builder, expandShape.getLoc(),
          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
          {origSize, origStride});
    }
  }

  // Now apply the origStride to the remaining dimensions.
  // These strides only depend on static sizes, so a plain scaling suffices.
  AffineExpr s0 = builder.getAffineSymbolExpr(0);
  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
    int64_t baseExpandedStride = expandedStrides[doneStrideIdx]
                                     .get<Attribute>()
                                     .cast<IntegerAttr>()
                                     .getInt();
    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
  }

  return expandedStrides;
}
/// Replace `baseBuffer, offset, sizes, strides =
/// extract_strided_metadata(expand_shape(memref))`
/// With
///
/// \verbatim
/// baseBuffer, offset, baseSizes, baseStrides =
/// extract_strided_metadata(memref)
/// sizes#reassIdx =
/// baseSizes#reassDim / product(expandShapeSizes#j,
/// for j in group excluding reassIdx)
/// strides#reassIdx =
/// baseStrides#reassDim * product(expandShapeSizes#j, for j in
/// reassIdx+1..reassIdx+group.size-1)
/// \endverbatim
///
/// Where reassIdx is a reassociation index for the group at reassDim
/// and expandShapeSizes#j is either:
/// - The constant size at dimension j, derived directly from the result type of
/// the expand_shape op, or
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
/// element.)
///
/// Notice that `baseBuffer` and `offset` are unchanged.
///
/// In other words, get rid of the expand_shape in that expression and
/// materialize its effects on the sizes and the strides using affine apply.
///
/// The pattern does not match when the source of the
/// extract_strided_metadata is not produced by a memref.expand_shape.
struct ExtractStridedMetadataOpExpandShapeFolder
: public OpRewritePattern<memref::ExtractStridedMetadataOp> {
public:
using OpRewritePattern<memref::ExtractStridedMetadataOp>::OpRewritePattern;
/// Rewrite \p op in terms of an extract_strided_metadata applied directly to
/// the source of the expand_shape feeding it. Returns failure() when \p op is
/// not fed by an expand_shape, success() once \p op has been replaced.
LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
PatternRewriter &rewriter) const override {
auto expandShape = op.getSource().getDefiningOp<memref::ExpandShapeOp>();
if (!expandShape)
return failure();
// Build a plain extract_strided_metadata(memref) from
// extract_strided_metadata(expand_shape(memref)).
Location origLoc = op.getLoc();
IndexType indexType = rewriter.getIndexType();
Value source = expandShape.getSrc();
auto sourceType = source.getType().cast<MemRefType>();
unsigned sourceRank = sourceType.getRank();
// The new op yields one index-typed size and one index-typed stride per
// dimension of the expand_shape's source.
SmallVector<Type> sizeStrideTypes(sourceRank, indexType);
auto newExtractStridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(
origLoc, op.getBaseBuffer().getType(), indexType, sizeStrideTypes,
sizeStrideTypes, source);
// Collect statically known information.
// Only `offset` is consumed below; `strides` is just the required output
// argument of getStridesAndOffset.
SmallVector<int64_t> strides;
int64_t offset;
bool hasKnownStridesAndOffset =
succeeded(getStridesAndOffset(sourceType, strides, offset));
(void)hasKnownStridesAndOffset;
assert(hasKnownStridesAndOffset &&
"getStridesAndOffset must work on valid expand_shape");
MemRefType expandShapeType = expandShape.getResultType();
unsigned expandShapeRank = expandShapeType.getRank();
// The result value will start with the base_buffer and offset.
// Layout of `results`: [baseBuffer, offset, sizes..., strides...], with
// expandShapeRank sizes followed by expandShapeRank strides.
unsigned baseIdxInResult = 2;
SmallVector<OpFoldResult> results(baseIdxInResult + expandShapeRank * 2);
results[0] = newExtractStridedMetadata.getBaseBuffer();
// Use the offset from the type when it is static, otherwise the offset
// result of the new extract_strided_metadata.
results[1] = ShapedType::isDynamicStrideOrOffset(offset)
? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
: rewriter.getIndexAttr(offset);
// Get the special case of 0-D out of the way.
// Here every size and every stride of the replacement is the constant 1;
// only the first two entries (base buffer and offset) come from `results`.
if (sourceRank == 0) {
Value constantOne = getValueOrCreateConstantIndexOp(
rewriter, origLoc, rewriter.getIndexAttr(1));
SmallVector<Value> resultValues(baseIdxInResult + expandShapeRank * 2,
constantOne);
for (unsigned i = 0; i < baseIdxInResult; ++i)
resultValues[i] =
getValueOrCreateConstantIndexOp(rewriter, origLoc, results[i]);
rewriter.replaceOp(op, resultValues);
return success();
}
// Compute the expanded strides and sizes from the base strides and sizes.
SmallVector<OpFoldResult> origSizes =
getAsOpFoldResult(newExtractStridedMetadata.getSizes());
SmallVector<OpFoldResult> origStrides =
getAsOpFoldResult(newExtractStridedMetadata.getStrides());
// Visit the reassociation groups in order. `idx` is declared outside the
// loop so the assert after it can check how many groups were visited.
unsigned idx = 0, endIdx = expandShape.getReassociationIndices().size();
for (; idx != endIdx; ++idx) {
SmallVector<OpFoldResult> expandedSizes =
getExpandedSizes(expandShape, rewriter, origSizes, /*groupId=*/idx);
SmallVector<OpFoldResult> expandedStrides = getExpandedStrides(
expandShape, rewriter, origSizes, origStrides, /*groupId=*/idx);
unsigned groupSize = expandShape.getReassociationIndices()[idx].size();
// baseIdxInResult points at the first size slot of the current group; the
// matching stride slot sits expandShapeRank entries further.
const unsigned sizeStartIdx = baseIdxInResult;
const unsigned strideStartIdx = sizeStartIdx + expandShapeRank;
for (unsigned i = 0; i < groupSize; ++i) {
results[sizeStartIdx + i] = expandedSizes[i];
results[strideStartIdx + i] = expandedStrides[i];
}
// Move on to the slots of the next group.
baseIdxInResult += groupSize;
}
assert(idx == sourceRank &&
"We should have visited all the input dimensions");
assert(baseIdxInResult == expandShapeRank + 2 &&
"We should have populated all the values");
// Turn every OpFoldResult into a Value, creating constants where needed, and
// replace the original op with them.
rewriter.replaceOp(
op, getValueOrCreateConstantIndexOp(rewriter, origLoc, results));
return success();
}
};
} // namespace
/// Add to \p patterns the rewrites that fold a subview or an expand_shape into
/// the extract_strided_metadata consuming it.
void memref::populateSimplifyExtractStridedMetadataOpPatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  // Registration order is kept: subview folder first, expand_shape second.
  patterns.add<ExtractStridedMetadataOpSubviewFolder>(context);
  patterns.add<ExtractStridedMetadataOpExpandShapeFolder>(context);
}
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
/// Pass that greedily applies the extract_strided_metadata simplification
/// patterns to all the regions of the operation it runs on.
struct SimplifyExtractStridedMetadataPass final
    : public memref::impl::SimplifyExtractStridedMetadataBase<
          SimplifyExtractStridedMetadataPass> {
  /// Collect the simplification patterns and run them with the greedy driver.
  /// The result of the driver is deliberately ignored.
  void runOnOperation() override {
    RewritePatternSet simplifications(&getContext());
    memref::populateSimplifyExtractStridedMetadataOpPatterns(simplifications);
    (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                       std::move(simplifications));
  }
};
} // namespace
/// Create a new instance of the extract_strided_metadata simplification pass.
std::unique_ptr<Pass> memref::createSimplifyExtractStridedMetadataPass() {
  std::unique_ptr<Pass> pass =
      std::make_unique<SimplifyExtractStridedMetadataPass>();
  return pass;
}