
The greedy rewriter is used in many different flows and offers a lot of convenience (worklist management, debugging actions, tracing, etc.). But it combines two independent kinds of greedy behavior: 1) how ops are matched, and 2) folding wherever it can. Keeping them coupled leads to inefficiency. For example, when a lowering requires distinct phases with patterns applied in a specific order across different passes, the driver ends up needlessly retrying folding over multiple rounds where one final run would have sufficed. Folks can avoid this locally by building their own driver, but separating the two is a commonly requested feature that keeps being worked around downstream in suboptimal ways. For downstream users there should be no behavioral change, and updating from the deprecated API should just be a find-and-replace (e.g., of the `find ./ -type f -exec sed -i 's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety), as the API arguments haven't changed between the two.
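For illustration, a minimal call-site sketch of the rename (the pass name `MyPass` is hypothetical; the arguments and the `GreedyPatternRewriteDriver.h` header stay the same):

```cpp
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical downstream pass; only the driver entry point's name changes.
void MyPass::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());
  // ... populate `patterns` as before ...

  // Before (deprecated):
  // (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

  // After: same arguments, new name.
  (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
}
```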
//===- FoldTensorSubsetOps.cpp - Fold tensor subset ops -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Fold tensor subset ops with producer / consumers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include <type_traits>

namespace mlir {
namespace tensor {
#define GEN_PASS_DEF_FOLDTENSORSUBSETOPS
#include "mlir/Dialect/Tensor/Transforms/Passes.h.inc"
} // namespace tensor
} // namespace mlir

using namespace mlir;

static Value getTensorOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getTensorOperand(tensor::InsertSliceOp op) {
  return op.getSource();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merge extract_slice operation with load/transferRead operation.
class TransferReadOfExtractSliceOpFolder final
    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  FailureOr<mlir::Value>
  matchAndRewriteMaskableOp(vector::TransferReadOp readOp,
                            vector::MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;
};

/// Merge insert_slice operation with store/transferWriteOp operation.
class InsertSliceOfTransferWriteOpFolder final
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertSliceOp,
                                PatternRewriter &rewriter) const override;

private:
  static bool
  doesTransferWriteCoverInsertSlice(vector::TransferWriteOp writeOp);
};
} // namespace

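/// Checks whether a vector transfer op and a tensor extract/insert slice op
/// can be folded together: the transfer must not have out-of-bounds dims or a
/// mask, and the slice must have unit strides.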
template <typename XferOp, typename ExtractOrInsertOp>
static LogicalResult preconditionsFoldExtractOrInsertWithTransferOp(
    RewriterBase &rewriter, XferOp xferOp,
    ExtractOrInsertOp extractOrInsertSliceOp) {
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp, "masked transfer");
  if (!extractOrInsertSliceOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride insert/extract, requires keeping track of "
                "strides, this may result in needing to insert "
                "vector.insert_strided_slice/extract_strided_slice ops");
  }
  return success();
}

FailureOr<mlir::Value>
TransferReadOfExtractSliceOpFolder::matchAndRewriteMaskableOp(
    vector::TransferReadOp readOp, vector::MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  auto extractSliceOp =
      getTensorOperand(readOp).getDefiningOp<tensor::ExtractSliceOp>();
  if (!extractSliceOp)
    return rewriter.notifyMatchFailure(readOp, "not an extract_slice");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, readOp,
                                                     extractSliceOp);
  if (failed(preconditionResult))
    return rewriter.notifyMatchFailure(readOp, "Failed preconditions");

  SmallVector<Value> indices(readOp.getIndices().begin(),
                             readOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, readOp.getLoc(), extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedStrides(), extractSliceOp.getDroppedDims(),
      indices, sourceIndices);

  Operation *newOp = rewriter.create<vector::TransferReadOp>(
      readOp.getLoc(), readOp.getVectorType(), extractSliceOp.getSource(),
      sourceIndices,
      AffineMapAttr::get(expandDimsToRank(
          readOp.getPermutationMap(), extractSliceOp.getSourceType().getRank(),
          extractSliceOp.getDroppedDims())),
      readOp.getPadding(),
      /*mask=*/Value(), readOp.getInBoundsAttr());
  if (maskOp)
    newOp = mlir::vector::maskOperation(rewriter, newOp, maskOp.getMask());
  return newOp->getResults()[0];
}

LogicalResult InsertSliceOfTransferWriteOpFolder::matchAndRewrite(
    tensor::InsertSliceOp insertSliceOp, PatternRewriter &rewriter) const {
  auto writeOp = getTensorOperand(insertSliceOp)
                     .template getDefiningOp<vector::TransferWriteOp>();
  if (!writeOp)
    return rewriter.notifyMatchFailure(insertSliceOp, "not a transfer_write");

  LogicalResult preconditionResult =
      preconditionsFoldExtractOrInsertWithTransferOp(rewriter, writeOp,
                                                     insertSliceOp);
  if (failed(preconditionResult))
    return preconditionResult;

  if (!doesTransferWriteCoverInsertSlice(writeOp))
    return rewriter.notifyMatchFailure(
        insertSliceOp, "transfer_write does not cover insert_slice");

  SmallVector<Value> indices(writeOp.getIndices().begin(),
                             writeOp.getIndices().end());
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, writeOp.getLoc(), insertSliceOp.getMixedOffsets(),
      insertSliceOp.getMixedStrides(), insertSliceOp.getDroppedDims(), indices,
      sourceIndices);

  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      insertSliceOp, writeOp.getValue(), insertSliceOp.getDest(), sourceIndices,
      AffineMapAttr::get(expandDimsToRank(writeOp.getPermutationMap(),
                                          insertSliceOp.getDestType().getRank(),
                                          insertSliceOp.getDroppedDims())),
      writeOp.getInBoundsAttr());

  return success();
}

bool InsertSliceOfTransferWriteOpFolder::doesTransferWriteCoverInsertSlice(
    vector::TransferWriteOp writeOp) {
  if (writeOp.getShapedType().hasStaticShape())
    return llvm::equal(writeOp.getVectorType().getShape(),
                       writeOp.getShapedType().getShape());

  // TODO: Use ValueBoundsConstraintSet for dynamic shapes.

  return false;
}

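/// Fold an insert_slice whose source is produced by another insert_slice with
/// matching sizes into a single insert_slice on the outer destination. Works
/// for both tensor.insert_slice and tensor.parallel_insert_slice.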
template <typename OpTy>
struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    auto sourceInsertSliceOp =
        insertSliceOp.getSource()
            .template getDefiningOp<tensor::InsertSliceOp>();
    if (!sourceInsertSliceOp)
      return failure();

    // TODO: relax unit stride assumption where possible.
    if (!insertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(insertSliceOp,
                                         "requires unit strides");
    }
    if (!sourceInsertSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(sourceInsertSliceOp,
                                         "requires unit strides");
    }

    int64_t srcDim = 0;
    llvm::SmallBitVector droppedDims = insertSliceOp.getDroppedDims();
    for (int64_t d = 0, e = insertSliceOp.getDestType().getRank(); d < e; ++d) {
      if (droppedDims[d])
        continue;
      if (insertSliceOp.getMixedSizes()[d] !=
          sourceInsertSliceOp.getMixedSizes()[srcDim++]) {
        return rewriter.notifyMatchFailure(
            sourceInsertSliceOp,
            "requires matching sizes to fold, otherwise a copy is needed");
      }
    }

    // Resolve sizes according to dropped dims.
    SmallVector<OpFoldResult> resolvedSizes;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveSizesIntoOpWithSizes(insertSliceOp.getMixedSizes(),
                                        sourceInsertSliceOp.getMixedSizes(),
                                        droppedDims, resolvedSizes);

    // If we are inside an InParallel region, temporarily set the insertion
    // point outside: only tensor.parallel_insert_slice ops are allowed in
    // there.
    if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
      rewriter.setInsertionPoint(
          insertSliceOp->template getParentOfType<scf::InParallelOp>());
    }

    // Resolve offsets according to source offsets and strides.
    SmallVector<Value> resolvedOffsets;
    // Note: the "insertSlice" case is symmetrical to the extract/subview case:
    // `insertSliceOp` is passed as the "source" and `sourceInsertSliceOp` is
    // passed as the destination to the helper function.
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, insertSliceOp.getLoc(), insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedStrides(), droppedDims,
        sourceInsertSliceOp.getMixedOffsets(), resolvedOffsets);

    // Reset the insertion point.
    rewriter.setInsertionPoint(insertSliceOp);
    // Replace original op.
    rewriter.replaceOpWithNewOp<OpTy>(
        insertSliceOp, sourceInsertSliceOp.getSource(), insertSliceOp.getDest(),
        getAsOpFoldResult(resolvedOffsets), resolvedSizes,
        insertSliceOp.getMixedStrides());

    return success();
  }
};

void tensor::populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns) {
  populateFoldTensorSubsetIntoVectorTransferPatterns(patterns);
  patterns.add<InsertSliceOfInsertSliceFolder<tensor::InsertSliceOp>,
               InsertSliceOfInsertSliceFolder<tensor::ParallelInsertSliceOp>>(
      patterns.getContext());
}

void tensor::populateFoldTensorSubsetIntoVectorTransferPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadOfExtractSliceOpFolder,
               InsertSliceOfTransferWriteOpFolder>(patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldTensorSubsetOpsPass final
    : public tensor::impl::FoldTensorSubsetOpsBase<FoldTensorSubsetOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldTensorSubsetOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  tensor::populateFoldTensorSubsetOpPatterns(patterns);
  (void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
  return std::make_unique<FoldTensorSubsetOpsPass>();
}