[mlir][NFC] Move around the code related to PatternRewriting to improve layering

There are several pieces of pattern rewriting infra in IR/ that really shouldn't be there. This revision moves those pieces to a better location such that they are easier to evolve in the future(e.g. with PDL). More concretely this revision does the following:

* Create a Transforms/GreedyPatternRewriteDriver.h and move the apply*andFold methods there.
The definitions for these methods are already in Transforms/ so it doesn't make sense for the declarations to be in IR.

* Create a new lib/Rewrite library and move PatternApplicator there.
This new library will be focused on applying rewrites, and will also include compiling rewrites with PDL.

Differential Revision: https://reviews.llvm.org/D89103
This commit is contained in:
River Riddle 2020-10-26 17:24:17 -07:00
parent b99bd77162
commit b6eb26fd0e
42 changed files with 358 additions and 293 deletions

View File

@ -416,12 +416,9 @@ private:
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
};
//===----------------------------------------------------------------------===//
// Pattern-driven rewriters
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// OwningRewritePatternList
//===----------------------------------------------------------------------===//
class OwningRewritePatternList {
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
@ -481,98 +478,6 @@ private:
PatternListT patterns;
};
//===----------------------------------------------------------------------===//
// PatternApplicator
/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
public:
/// The cost model dynamically assigns a PatternBenefit to a particular
/// pattern. Users can query contained patterns and pass analysis results to
/// applyCostModel. Patterns to be discarded should have a benefit of
/// `impossibleToMatch`.
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
: owningPatternList(owningPatternList) {}
/// Attempt to match and rewrite the given op with any pattern, allowing a
/// predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
///
/// canApply: called before each match and rewrite attempt; return false to
/// skip pattern.
/// onFailure: called when a pattern fails to match to perform cleanup.
/// onSuccess: called when a pattern match succeeds; return failure() to
/// invalidate the match and try another pattern.
LogicalResult
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply = {},
function_ref<void(const Pattern &)> onFailure = {},
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
/// Apply a cost model to the patterns within this applicator.
void applyCostModel(CostModel model);
/// Apply the default cost model that solely uses the pattern's static
/// benefit.
void applyDefaultCostModel() {
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
}
/// Walk all of the patterns within the applicator.
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
private:
/// Attempt to match and rewrite the given op with the given pattern, allowing
/// a predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
LogicalResult
matchAndRewrite(Operation *op, const RewritePattern &pattern,
PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess);
/// The list that owns the patterns used within this applicator.
const OwningRewritePatternList &owningPatternList;
/// The set of patterns to match for each operation, stable sorted by benefit.
DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
/// The set of patterns that may match against any operation type, stable
/// sorted by benefit.
SmallVector<RewritePattern *, 1> anyOpPatterns;
};
//===----------------------------------------------------------------------===//
// applyPatternsGreedily
//===----------------------------------------------------------------------===//
/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return success if no more patterns can be matched
/// in the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself. Note:
/// These methods also perform folding and simple dead-code elimination
/// before attempting to match any of the provided patterns.
///
LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const OwningRewritePatternList &patterns);
/// Rewrite the given regions, which must be isolated from above.
LogicalResult
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
const OwningRewritePatternList &patterns);
/// Applies the specified patterns on `op` alone while also trying to fold it,
/// by selecting the highest benefits patterns in a greedy manner. Returns
/// success if no more patterns can be matched. `erased` is set to true if `op`
/// was folded away or erased as a result of becoming dead. Note: This does not
/// apply any patterns recursively to the regions of `op`.
LogicalResult applyOpPatternsAndFold(Operation *op,
const OwningRewritePatternList &patterns,
bool *erased = nullptr);
} // end namespace mlir
#endif // MLIR_PATTERN_MATCH_H

View File

@ -0,0 +1,85 @@
//===- PatternApplicator.h - PatternApplicator -------==---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an applicator that applies pattern rewrites based upon a
// user defined cost model.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
#define MLIR_REWRITE_PATTERNAPPLICATOR_H
#include "mlir/IR/PatternMatch.h"
namespace mlir {
class PatternRewriter;
/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
public:
/// The cost model dynamically assigns a PatternBenefit to a particular
/// pattern. Users can query contained patterns and pass analysis results to
/// applyCostModel. Patterns to be discarded should have a benefit of
/// `impossibleToMatch`.
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
: owningPatternList(owningPatternList) {}
/// Attempt to match and rewrite the given op with any pattern, allowing a
/// predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
///
/// canApply: called before each match and rewrite attempt; return false to
/// skip pattern.
/// onFailure: called when a pattern fails to match to perform cleanup.
/// onSuccess: called when a pattern match succeeds; return failure() to
/// invalidate the match and try another pattern.
LogicalResult
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply = {},
function_ref<void(const Pattern &)> onFailure = {},
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
/// Apply a cost model to the patterns within this applicator.
void applyCostModel(CostModel model);
/// Apply the default cost model that solely uses the pattern's static
/// benefit.
void applyDefaultCostModel() {
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
}
/// Walk all of the patterns within the applicator.
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
private:
/// Attempt to match and rewrite the given op with the given pattern, allowing
/// a predicate to decide if a pattern can be applied or not, and hooks for if
/// the pattern match was a success or failure.
LogicalResult
matchAndRewrite(Operation *op, const RewritePattern &pattern,
PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess);
/// The list that owns the patterns used within this applicator.
const OwningRewritePatternList &owningPatternList;
/// The set of patterns to match for each operation, stable sorted by benefit.
DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
/// The set of patterns that may match against any operation type, stable
/// sorted by benefit.
SmallVector<RewritePattern *, 1> anyOpPatterns;
};
} // end namespace mlir
#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H

View File

@ -0,0 +1,52 @@
//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares methods for applying a set of patterns greedily, choosing
// the patterns with the highest local benefit, until a fixed point is reached.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
#include "mlir/IR/PatternMatch.h"
namespace mlir {
//===----------------------------------------------------------------------===//
// applyPatternsGreedily
//===----------------------------------------------------------------------===//
/// Rewrite the regions of the specified operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
/// work-list driven manner. Return success if no more patterns can be matched
/// in the result operation regions.
/// Note: This does not apply patterns to the top-level operation itself. Note:
/// These methods also perform folding and simple dead-code elimination
/// before attempting to match any of the provided patterns.
///
LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const OwningRewritePatternList &patterns);
/// Rewrite the given regions, which must be isolated from above.
LogicalResult
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
const OwningRewritePatternList &patterns);
/// Applies the specified patterns on `op` alone while also trying to fold it,
/// by selecting the highest benefits patterns in a greedy manner. Returns
/// success if no more patterns can be matched. `erased` is set to true if `op`
/// was folded away or erased as a result of becoming dead. Note: This does not
/// apply any patterns recursively to the regions of `op`.
LogicalResult applyOpPatternsAndFold(Operation *op,
const OwningRewritePatternList &patterns,
bool *erased = nullptr);
} // end namespace mlir
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_

View File

@ -12,6 +12,7 @@ add_subdirectory(Interfaces)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Reducer)
add_subdirectory(Rewrite)
add_subdirectory(Support)
add_subdirectory(TableGen)
add_subdirectory(Target)

View File

@ -19,6 +19,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "../GPUCommon/GPUOpsLowering.h"

View File

@ -22,6 +22,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "../GPUCommon/GPUOpsLowering.h"

View File

@ -15,6 +15,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@ -123,7 +124,7 @@ class ConvertShapeConstraints
OwningRewritePatternList patterns;
populateConvertShapeConstraintsConversionPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(func, patterns)))
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -17,8 +17,8 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

View File

@ -15,16 +15,12 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Target/LLVMIR/TypeTranslation.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"

View File

@ -24,14 +24,10 @@
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;

View File

@ -24,7 +24,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/CommandLine.h"

View File

@ -15,7 +15,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Utils.h"
#define DEBUG_TYPE "simplify-affine-structure"

View File

@ -16,7 +16,7 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

View File

@ -17,6 +17,7 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"

View File

@ -20,9 +20,8 @@
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

View File

@ -23,9 +23,9 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

View File

@ -20,6 +20,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;

View File

@ -22,6 +22,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

View File

@ -22,8 +22,8 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/CommandLine.h"

View File

@ -21,9 +21,9 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

View File

@ -12,10 +12,9 @@
#include "mlir/Dialect/Quant/QuantizeUtils.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::quant;

View File

@ -11,9 +11,8 @@
#include "mlir/Dialect/Quant/Passes.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::quant;

View File

@ -10,6 +10,7 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

View File

@ -8,14 +8,9 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "pattern-match"
//===----------------------------------------------------------------------===//
// PatternBenefit
//===----------------------------------------------------------------------===//
@ -205,135 +200,3 @@ void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
//===----------------------------------------------------------------------===//
// PatternApplicator
//===----------------------------------------------------------------------===//
void PatternApplicator::applyCostModel(CostModel model) {
// Separate patterns by root kind to simplify lookup later on.
patterns.clear();
anyOpPatterns.clear();
for (const auto &pat : owningPatternList) {
// If the pattern is always impossible to match, just ignore it.
if (pat->getBenefit().isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs()
<< "Ignoring pattern '" << pat->getRootKind()
<< "' because it is impossible to match (by pattern benefit)\n";
});
continue;
}
if (Optional<OperationName> opName = pat->getRootKind())
patterns[*opName].push_back(pat.get());
else
anyOpPatterns.push_back(pat.get());
}
// Sort the patterns using the provided cost model.
llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
return benefits[lhs] > benefits[rhs];
};
auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
// Special case for one pattern in the list, which is the most common case.
if (list.size() == 1) {
if (model(*list.front()).isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
<< "' because it is impossible to match or cannot lead "
"to legal IR (by cost model)\n";
});
list.clear();
}
return;
}
// Collect the dynamic benefits for the current pattern list.
benefits.clear();
for (RewritePattern *pat : list)
benefits.try_emplace(pat, model(*pat));
// Sort patterns with highest benefit first, and remove those that are
// impossible to match.
std::stable_sort(list.begin(), list.end(), cmp);
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
<< "' because it is impossible to match or cannot lead to "
"legal IR (by cost model)\n";
});
list.pop_back();
}
};
for (auto &it : patterns)
processPatternList(it.second);
processPatternList(anyOpPatterns);
}
void PatternApplicator::walkAllPatterns(
function_ref<void(const Pattern &)> walk) {
for (auto &it : owningPatternList)
walk(*it);
}
LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check to see if there are patterns matching this specific operation type.
MutableArrayRef<RewritePattern *> opPatterns;
auto patternIt = patterns.find(op->getName());
if (patternIt != patterns.end())
opPatterns = patternIt->second;
// Process the patterns for that match the specific operation type, and any
// operation type in an interleaved fashion.
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
// and pass in a comparison function. That would make this code trivial.
auto opIt = opPatterns.begin(), opE = opPatterns.end();
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
while (opIt != opE && anyIt != anyE) {
// Try to match the pattern providing the most benefit.
RewritePattern *pattern;
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
pattern = *(opIt++);
else
pattern = *(anyIt++);
// Otherwise, try to match the generic pattern.
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
// If we break from the loop, then only one of the ranges can still have
// elements. Loop over both without checking given that we don't need to
// interleave anymore.
for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
return failure();
}
LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check that the pattern can be applied.
if (canApply && !canApply(pattern))
return failure();
// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite.
rewriter.setInsertionPoint(op);
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
return success(!onSuccess || succeeded(onSuccess(pattern)));
if (onFailure)
onFailure(pattern);
return failure();
}

View File

@ -0,0 +1,12 @@
add_mlir_library(MLIRRewrite
PatternApplicator.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
DEPENDS
mlir-generic-headers
LINK_LIBS PUBLIC
MLIRIR
)

View File

@ -0,0 +1,148 @@
//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an applicator that applies pattern rewrites based upon a
// user defined cost model.
//
//===----------------------------------------------------------------------===//
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "pattern-match"
void PatternApplicator::applyCostModel(CostModel model) {
// Separate patterns by root kind to simplify lookup later on.
patterns.clear();
anyOpPatterns.clear();
for (const auto &pat : owningPatternList) {
// If the pattern is always impossible to match, just ignore it.
if (pat->getBenefit().isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs()
<< "Ignoring pattern '" << pat->getRootKind()
<< "' because it is impossible to match (by pattern benefit)\n";
});
continue;
}
if (Optional<OperationName> opName = pat->getRootKind())
patterns[*opName].push_back(pat.get());
else
anyOpPatterns.push_back(pat.get());
}
// Sort the patterns using the provided cost model.
llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
return benefits[lhs] > benefits[rhs];
};
auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
// Special case for one pattern in the list, which is the most common case.
if (list.size() == 1) {
if (model(*list.front()).isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
<< "' because it is impossible to match or cannot lead "
"to legal IR (by cost model)\n";
});
list.clear();
}
return;
}
// Collect the dynamic benefits for the current pattern list.
benefits.clear();
for (RewritePattern *pat : list)
benefits.try_emplace(pat, model(*pat));
// Sort patterns with highest benefit first, and remove those that are
// impossible to match.
std::stable_sort(list.begin(), list.end(), cmp);
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
LLVM_DEBUG({
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
<< "' because it is impossible to match or cannot lead to "
"legal IR (by cost model)\n";
});
list.pop_back();
}
};
for (auto &it : patterns)
processPatternList(it.second);
processPatternList(anyOpPatterns);
}
void PatternApplicator::walkAllPatterns(
function_ref<void(const Pattern &)> walk) {
for (auto &it : owningPatternList)
walk(*it);
}
LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check to see if there are patterns matching this specific operation type.
MutableArrayRef<RewritePattern *> opPatterns;
auto patternIt = patterns.find(op->getName());
if (patternIt != patterns.end())
opPatterns = patternIt->second;
// Process the patterns for that match the specific operation type, and any
// operation type in an interleaved fashion.
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
// and pass in a comparison function. That would make this code trivial.
auto opIt = opPatterns.begin(), opE = opPatterns.end();
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
while (opIt != opE && anyIt != anyE) {
// Try to match the pattern providing the most benefit.
RewritePattern *pattern;
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
pattern = *(opIt++);
else
pattern = *(anyIt++);
// Otherwise, try to match the generic pattern.
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
// If we break from the loop, then only one of the ranges can still have
// elements. Loop over both without checking given that we don't need to
// interleave anymore.
for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
onSuccess)))
return success();
}
return failure();
}
LogicalResult PatternApplicator::matchAndRewrite(
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
function_ref<bool(const Pattern &)> canApply,
function_ref<void(const Pattern &)> onFailure,
function_ref<LogicalResult(const Pattern &)> onSuccess) {
// Check that the pattern can be applied.
if (canApply && !canApply(pattern))
return failure();
// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite.
rewriter.setInsertionPoint(op);
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
return success(!onSuccess || succeeded(onSuccess(pattern)));
if (onFailure)
onFailure(pattern);
return failure();
}

View File

@ -7,7 +7,6 @@ add_mlir_library(MLIRTransforms
Canonicalizer.cpp
CopyRemoval.cpp
CSE.cpp
DialectConversion.cpp
Inliner.cpp
LocationSnapshot.cpp
LoopCoalescing.cpp

View File

@ -12,8 +12,8 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;

View File

@ -17,6 +17,7 @@
#include "mlir/Analysis/CallGraph.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"

View File

@ -1,4 +1,5 @@
add_mlir_library(MLIRTransformUtils
DialectConversion.cpp
FoldUtils.cpp
GreedyPatternRewriteDriver.cpp
InliningUtils.cpp
@ -19,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
MLIRLoopAnalysis
MLIRSCF
MLIRPass
MLIRRewrite
MLIRStandard
)

View File

@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
@ -74,8 +75,7 @@ computeConversionSet(iterator_range<Region::iterator> region,
/// A utility function to log a successful result for the given reason.
template <typename... Args>
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
Args &&... args) {
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> SUCCESS";
@ -88,8 +88,7 @@ static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
/// A utility function to log a failure result for the given reason.
template <typename... Args>
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
Args &&... args) {
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> FAILURE : "
@ -2033,21 +2032,21 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
return minDepth;
// Sort the patterns by those likely to be the most beneficial.
llvm::array_pod_sort(
patternsByDepth.begin(), patternsByDepth.end(),
[](const std::pair<const Pattern *, unsigned> *lhs,
const std::pair<const Pattern *, unsigned> *rhs) {
// First sort by the smaller pattern legalization depth.
if (lhs->second != rhs->second)
return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
&rhs->second);
llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
[](const std::pair<const Pattern *, unsigned> *lhs,
const std::pair<const Pattern *, unsigned> *rhs) {
// First sort by the smaller pattern legalization
// depth.
if (lhs->second != rhs->second)
return llvm::array_pod_sort_comparator<unsigned>(
&lhs->second, &rhs->second);
// Then sort by the larger pattern benefit.
auto lhsBenefit = lhs->first->getBenefit();
auto rhsBenefit = rhs->first->getBenefit();
return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
&lhsBenefit);
});
// Then sort by the larger pattern benefit.
auto lhsBenefit = lhs->first->getBenefit();
auto rhsBenefit = rhs->first->getBenefit();
return llvm::array_pod_sort_comparator<PatternBenefit>(
&rhsBenefit, &lhsBenefit);
});
// Update the legalization pattern to use the new sorted list.
patterns.clear();

View File

@ -10,8 +10,9 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -23,8 +23,8 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"

View File

@ -13,8 +13,8 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"

View File

@ -10,10 +10,10 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

View File

@ -7,9 +7,8 @@
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@ -25,9 +24,9 @@ OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
ArrayRef<Attribute> operands) {
auto argument_op = getOperand();
auto argumentOp = getOperand();
// The success case should cause the trait fold to be supressed.
return argument_op.getDefiningOp() ? argument_op : OpFoldResult{};
return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
}
namespace {

View File

@ -14,6 +14,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
@ -93,7 +94,7 @@ void TestConvVectorization::runOnOperation() {
// VectorTransforms.cpp
vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
applyPatternsAndFoldGreedily(module, vectorTransferPatterns);
applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
// Programmatic controlled lowering of linalg.copy and linalg.fill.
PassManager pm(context);
@ -105,13 +106,14 @@ void TestConvVectorization::runOnOperation() {
OwningRewritePatternList vectorContractLoweringPatterns;
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
context, vectorTransformsOptions);
applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
applyPatternsAndFoldGreedily(module,
std::move(vectorContractLoweringPatterns));
// Programmatic controlled lowering of vector.transfer only.
OwningRewritePatternList vectorToLoopsPatterns;
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
VectorTransferToSCFOptions());
applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
// Ensure we drop the marker in the end.
module.walk([](linalg::LinalgOp op) {

View File

@ -11,8 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

View File

@ -12,8 +12,8 @@
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;

View File

@ -13,6 +13,7 @@
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;

View File

@ -17,8 +17,8 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"

View File

@ -14,9 +14,8 @@
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::vector;