[mlir][NFC] Move around the code related to PatternRewriting to improve layering
There are several pieces of pattern rewriting infra in IR/ that really shouldn't be there. This revision moves those pieces to a better location such that they are easier to evolve in the future(e.g. with PDL). More concretely this revision does the following: * Create a Transforms/GreedyPatternRewriteDriver.h and move the apply*andFold methods there. The definitions for these methods are already in Transforms/ so it doesn't make sense for the declarations to be in IR. * Create a new lib/Rewrite library and move PatternApplicator there. This new library will be focused on applying rewrites, and will also include compiling rewrites with PDL. Differential Revision: https://reviews.llvm.org/D89103
This commit is contained in:
parent
b99bd77162
commit
b6eb26fd0e
@ -416,12 +416,9 @@ private:
|
||||
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-driven rewriters
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OwningRewritePatternList
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class OwningRewritePatternList {
|
||||
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
||||
@ -481,98 +478,6 @@ private:
|
||||
PatternListT patterns;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternApplicator
|
||||
|
||||
/// This class manages the application of a group of rewrite patterns, with a
|
||||
/// user-provided cost model.
|
||||
class PatternApplicator {
|
||||
public:
|
||||
/// The cost model dynamically assigns a PatternBenefit to a particular
|
||||
/// pattern. Users can query contained patterns and pass analysis results to
|
||||
/// applyCostModel. Patterns to be discarded should have a benefit of
|
||||
/// `impossibleToMatch`.
|
||||
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
|
||||
|
||||
explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
|
||||
: owningPatternList(owningPatternList) {}
|
||||
|
||||
/// Attempt to match and rewrite the given op with any pattern, allowing a
|
||||
/// predicate to decide if a pattern can be applied or not, and hooks for if
|
||||
/// the pattern match was a success or failure.
|
||||
///
|
||||
/// canApply: called before each match and rewrite attempt; return false to
|
||||
/// skip pattern.
|
||||
/// onFailure: called when a pattern fails to match to perform cleanup.
|
||||
/// onSuccess: called when a pattern match succeeds; return failure() to
|
||||
/// invalidate the match and try another pattern.
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply = {},
|
||||
function_ref<void(const Pattern &)> onFailure = {},
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
|
||||
|
||||
/// Apply a cost model to the patterns within this applicator.
|
||||
void applyCostModel(CostModel model);
|
||||
|
||||
/// Apply the default cost model that solely uses the pattern's static
|
||||
/// benefit.
|
||||
void applyDefaultCostModel() {
|
||||
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
|
||||
}
|
||||
|
||||
/// Walk all of the patterns within the applicator.
|
||||
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
|
||||
|
||||
private:
|
||||
/// Attempt to match and rewrite the given op with the given pattern, allowing
|
||||
/// a predicate to decide if a pattern can be applied or not, and hooks for if
|
||||
/// the pattern match was a success or failure.
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, const RewritePattern &pattern,
|
||||
PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply,
|
||||
function_ref<void(const Pattern &)> onFailure,
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess);
|
||||
|
||||
/// The list that owns the patterns used within this applicator.
|
||||
const OwningRewritePatternList &owningPatternList;
|
||||
|
||||
/// The set of patterns to match for each operation, stable sorted by benefit.
|
||||
DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
|
||||
/// The set of patterns that may match against any operation type, stable
|
||||
/// sorted by benefit.
|
||||
SmallVector<RewritePattern *, 1> anyOpPatterns;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// applyPatternsGreedily
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Rewrite the regions of the specified operation, which must be isolated from
|
||||
/// above, by repeatedly applying the highest benefit patterns in a greedy
|
||||
/// work-list driven manner. Return success if no more patterns can be matched
|
||||
/// in the result operation regions.
|
||||
/// Note: This does not apply patterns to the top-level operation itself. Note:
|
||||
/// These methods also perform folding and simple dead-code elimination
|
||||
/// before attempting to match any of the provided patterns.
|
||||
///
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(Operation *op,
|
||||
const OwningRewritePatternList &patterns);
|
||||
/// Rewrite the given regions, which must be isolated from above.
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
const OwningRewritePatternList &patterns);
|
||||
|
||||
/// Applies the specified patterns on `op` alone while also trying to fold it,
|
||||
/// by selecting the highest benefits patterns in a greedy manner. Returns
|
||||
/// success if no more patterns can be matched. `erased` is set to true if `op`
|
||||
/// was folded away or erased as a result of becoming dead. Note: This does not
|
||||
/// apply any patterns recursively to the regions of `op`.
|
||||
LogicalResult applyOpPatternsAndFold(Operation *op,
|
||||
const OwningRewritePatternList &patterns,
|
||||
bool *erased = nullptr);
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_PATTERN_MATCH_H
|
||||
|
85
mlir/include/mlir/Rewrite/PatternApplicator.h
Normal file
85
mlir/include/mlir/Rewrite/PatternApplicator.h
Normal file
@ -0,0 +1,85 @@
|
||||
//===- PatternApplicator.h - PatternApplicator -------==---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements an applicator that applies pattern rewrites based upon a
|
||||
// user defined cost model.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
|
||||
#define MLIR_REWRITE_PATTERNAPPLICATOR_H
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
class PatternRewriter;
|
||||
|
||||
/// This class manages the application of a group of rewrite patterns, with a
|
||||
/// user-provided cost model.
|
||||
class PatternApplicator {
|
||||
public:
|
||||
/// The cost model dynamically assigns a PatternBenefit to a particular
|
||||
/// pattern. Users can query contained patterns and pass analysis results to
|
||||
/// applyCostModel. Patterns to be discarded should have a benefit of
|
||||
/// `impossibleToMatch`.
|
||||
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
|
||||
|
||||
explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
|
||||
: owningPatternList(owningPatternList) {}
|
||||
|
||||
/// Attempt to match and rewrite the given op with any pattern, allowing a
|
||||
/// predicate to decide if a pattern can be applied or not, and hooks for if
|
||||
/// the pattern match was a success or failure.
|
||||
///
|
||||
/// canApply: called before each match and rewrite attempt; return false to
|
||||
/// skip pattern.
|
||||
/// onFailure: called when a pattern fails to match to perform cleanup.
|
||||
/// onSuccess: called when a pattern match succeeds; return failure() to
|
||||
/// invalidate the match and try another pattern.
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply = {},
|
||||
function_ref<void(const Pattern &)> onFailure = {},
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
|
||||
|
||||
/// Apply a cost model to the patterns within this applicator.
|
||||
void applyCostModel(CostModel model);
|
||||
|
||||
/// Apply the default cost model that solely uses the pattern's static
|
||||
/// benefit.
|
||||
void applyDefaultCostModel() {
|
||||
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
|
||||
}
|
||||
|
||||
/// Walk all of the patterns within the applicator.
|
||||
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
|
||||
|
||||
private:
|
||||
/// Attempt to match and rewrite the given op with the given pattern, allowing
|
||||
/// a predicate to decide if a pattern can be applied or not, and hooks for if
|
||||
/// the pattern match was a success or failure.
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, const RewritePattern &pattern,
|
||||
PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply,
|
||||
function_ref<void(const Pattern &)> onFailure,
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess);
|
||||
|
||||
/// The list that owns the patterns used within this applicator.
|
||||
const OwningRewritePatternList &owningPatternList;
|
||||
|
||||
/// The set of patterns to match for each operation, stable sorted by benefit.
|
||||
DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
|
||||
/// The set of patterns that may match against any operation type, stable
|
||||
/// sorted by benefit.
|
||||
SmallVector<RewritePattern *, 1> anyOpPatterns;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
|
52
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Normal file
52
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Normal file
@ -0,0 +1,52 @@
|
||||
//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares methods for applying a set of patterns greedily, choosing
|
||||
// the patterns with the highest local benefit, until a fixed point is reached.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
|
||||
#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// applyPatternsGreedily
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Rewrite the regions of the specified operation, which must be isolated from
|
||||
/// above, by repeatedly applying the highest benefit patterns in a greedy
|
||||
/// work-list driven manner. Return success if no more patterns can be matched
|
||||
/// in the result operation regions.
|
||||
/// Note: This does not apply patterns to the top-level operation itself. Note:
|
||||
/// These methods also perform folding and simple dead-code elimination
|
||||
/// before attempting to match any of the provided patterns.
|
||||
///
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(Operation *op,
|
||||
const OwningRewritePatternList &patterns);
|
||||
/// Rewrite the given regions, which must be isolated from above.
|
||||
LogicalResult
|
||||
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||
const OwningRewritePatternList &patterns);
|
||||
|
||||
/// Applies the specified patterns on `op` alone while also trying to fold it,
|
||||
/// by selecting the highest benefits patterns in a greedy manner. Returns
|
||||
/// success if no more patterns can be matched. `erased` is set to true if `op`
|
||||
/// was folded away or erased as a result of becoming dead. Note: This does not
|
||||
/// apply any patterns recursively to the regions of `op`.
|
||||
LogicalResult applyOpPatternsAndFold(Operation *op,
|
||||
const OwningRewritePatternList &patterns,
|
||||
bool *erased = nullptr);
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
|
@ -12,6 +12,7 @@ add_subdirectory(Interfaces)
|
||||
add_subdirectory(Parser)
|
||||
add_subdirectory(Pass)
|
||||
add_subdirectory(Reducer)
|
||||
add_subdirectory(Rewrite)
|
||||
add_subdirectory(Support)
|
||||
add_subdirectory(TableGen)
|
||||
add_subdirectory(Target)
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include "../GPUCommon/GPUOpsLowering.h"
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include "../GPUCommon/GPUOpsLowering.h"
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@ -123,7 +124,7 @@ class ConvertShapeConstraints
|
||||
OwningRewritePatternList patterns;
|
||||
populateConvertShapeConstraintsConversionPatterns(patterns, context);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(func, patterns)))
|
||||
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -17,8 +17,8 @@
|
||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -15,16 +15,12 @@
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Target/LLVMIR/TypeTranslation.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/IR/DerivedTypes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
@ -24,14 +24,10 @@
|
||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -24,7 +24,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
@ -15,7 +15,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Affine/Passes.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
|
||||
#define DEBUG_TYPE "simplify-affine-structure"
|
||||
|
@ -16,7 +16,7 @@
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
@ -20,9 +20,8 @@
|
||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
|
@ -23,9 +23,9 @@
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Dominance.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
|
@ -22,8 +22,8 @@
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineExprVisitor.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
||||
|
@ -21,9 +21,9 @@
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <type_traits>
|
||||
|
@ -12,10 +12,9 @@
|
||||
#include "mlir/Dialect/Quant/QuantizeUtils.h"
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::quant;
|
||||
|
@ -11,9 +11,8 @@
|
||||
#include "mlir/Dialect/Quant/Passes.h"
|
||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||
#include "mlir/Dialect/Quant/UniformSupport.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::quant;
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -8,14 +8,9 @@
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define DEBUG_TYPE "pattern-match"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternBenefit
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -205,135 +200,3 @@ void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) {
|
||||
cloneRegionBefore(region, *before->getParent(), before->getIterator());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PatternApplicator
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void PatternApplicator::applyCostModel(CostModel model) {
|
||||
// Separate patterns by root kind to simplify lookup later on.
|
||||
patterns.clear();
|
||||
anyOpPatterns.clear();
|
||||
for (const auto &pat : owningPatternList) {
|
||||
// If the pattern is always impossible to match, just ignore it.
|
||||
if (pat->getBenefit().isImpossibleToMatch()) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs()
|
||||
<< "Ignoring pattern '" << pat->getRootKind()
|
||||
<< "' because it is impossible to match (by pattern benefit)\n";
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (Optional<OperationName> opName = pat->getRootKind())
|
||||
patterns[*opName].push_back(pat.get());
|
||||
else
|
||||
anyOpPatterns.push_back(pat.get());
|
||||
}
|
||||
|
||||
// Sort the patterns using the provided cost model.
|
||||
llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
|
||||
auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
|
||||
return benefits[lhs] > benefits[rhs];
|
||||
};
|
||||
auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
|
||||
// Special case for one pattern in the list, which is the most common case.
|
||||
if (list.size() == 1) {
|
||||
if (model(*list.front()).isImpossibleToMatch()) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
|
||||
<< "' because it is impossible to match or cannot lead "
|
||||
"to legal IR (by cost model)\n";
|
||||
});
|
||||
list.clear();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect the dynamic benefits for the current pattern list.
|
||||
benefits.clear();
|
||||
for (RewritePattern *pat : list)
|
||||
benefits.try_emplace(pat, model(*pat));
|
||||
|
||||
// Sort patterns with highest benefit first, and remove those that are
|
||||
// impossible to match.
|
||||
std::stable_sort(list.begin(), list.end(), cmp);
|
||||
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
|
||||
<< "' because it is impossible to match or cannot lead to "
|
||||
"legal IR (by cost model)\n";
|
||||
});
|
||||
list.pop_back();
|
||||
}
|
||||
};
|
||||
for (auto &it : patterns)
|
||||
processPatternList(it.second);
|
||||
processPatternList(anyOpPatterns);
|
||||
}
|
||||
|
||||
void PatternApplicator::walkAllPatterns(
|
||||
function_ref<void(const Pattern &)> walk) {
|
||||
for (auto &it : owningPatternList)
|
||||
walk(*it);
|
||||
}
|
||||
|
||||
LogicalResult PatternApplicator::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply,
|
||||
function_ref<void(const Pattern &)> onFailure,
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
||||
// Check to see if there are patterns matching this specific operation type.
|
||||
MutableArrayRef<RewritePattern *> opPatterns;
|
||||
auto patternIt = patterns.find(op->getName());
|
||||
if (patternIt != patterns.end())
|
||||
opPatterns = patternIt->second;
|
||||
|
||||
// Process the patterns for that match the specific operation type, and any
|
||||
// operation type in an interleaved fashion.
|
||||
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
|
||||
// and pass in a comparison function. That would make this code trivial.
|
||||
auto opIt = opPatterns.begin(), opE = opPatterns.end();
|
||||
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
|
||||
while (opIt != opE && anyIt != anyE) {
|
||||
// Try to match the pattern providing the most benefit.
|
||||
RewritePattern *pattern;
|
||||
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
|
||||
pattern = *(opIt++);
|
||||
else
|
||||
pattern = *(anyIt++);
|
||||
|
||||
// Otherwise, try to match the generic pattern.
|
||||
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
||||
onSuccess)))
|
||||
return success();
|
||||
}
|
||||
// If we break from the loop, then only one of the ranges can still have
|
||||
// elements. Loop over both without checking given that we don't need to
|
||||
// interleave anymore.
|
||||
for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
|
||||
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
|
||||
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
||||
onSuccess)))
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult PatternApplicator::matchAndRewrite(
|
||||
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply,
|
||||
function_ref<void(const Pattern &)> onFailure,
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
||||
// Check that the pattern can be applied.
|
||||
if (canApply && !canApply(pattern))
|
||||
return failure();
|
||||
|
||||
// Try to match and rewrite this pattern. The patterns are sorted by
|
||||
// benefit, so if we match we can immediately rewrite.
|
||||
rewriter.setInsertionPoint(op);
|
||||
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
|
||||
return success(!onSuccess || succeeded(onSuccess(pattern)));
|
||||
|
||||
if (onFailure)
|
||||
onFailure(pattern);
|
||||
return failure();
|
||||
}
|
||||
|
12
mlir/lib/Rewrite/CMakeLists.txt
Normal file
12
mlir/lib/Rewrite/CMakeLists.txt
Normal file
@ -0,0 +1,12 @@
|
||||
add_mlir_library(MLIRRewrite
|
||||
PatternApplicator.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
|
||||
|
||||
DEPENDS
|
||||
mlir-generic-headers
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
)
|
148
mlir/lib/Rewrite/PatternApplicator.cpp
Normal file
148
mlir/lib/Rewrite/PatternApplicator.cpp
Normal file
@ -0,0 +1,148 @@
|
||||
//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements an applicator that applies pattern rewrites based upon a
|
||||
// user defined cost model.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define DEBUG_TYPE "pattern-match"
|
||||
|
||||
void PatternApplicator::applyCostModel(CostModel model) {
|
||||
// Separate patterns by root kind to simplify lookup later on.
|
||||
patterns.clear();
|
||||
anyOpPatterns.clear();
|
||||
for (const auto &pat : owningPatternList) {
|
||||
// If the pattern is always impossible to match, just ignore it.
|
||||
if (pat->getBenefit().isImpossibleToMatch()) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs()
|
||||
<< "Ignoring pattern '" << pat->getRootKind()
|
||||
<< "' because it is impossible to match (by pattern benefit)\n";
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (Optional<OperationName> opName = pat->getRootKind())
|
||||
patterns[*opName].push_back(pat.get());
|
||||
else
|
||||
anyOpPatterns.push_back(pat.get());
|
||||
}
|
||||
|
||||
// Sort the patterns using the provided cost model.
|
||||
llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
|
||||
auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
|
||||
return benefits[lhs] > benefits[rhs];
|
||||
};
|
||||
auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
|
||||
// Special case for one pattern in the list, which is the most common case.
|
||||
if (list.size() == 1) {
|
||||
if (model(*list.front()).isImpossibleToMatch()) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
|
||||
<< "' because it is impossible to match or cannot lead "
|
||||
"to legal IR (by cost model)\n";
|
||||
});
|
||||
list.clear();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Collect the dynamic benefits for the current pattern list.
|
||||
benefits.clear();
|
||||
for (RewritePattern *pat : list)
|
||||
benefits.try_emplace(pat, model(*pat));
|
||||
|
||||
// Sort patterns with highest benefit first, and remove those that are
|
||||
// impossible to match.
|
||||
std::stable_sort(list.begin(), list.end(), cmp);
|
||||
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
|
||||
LLVM_DEBUG({
|
||||
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
|
||||
<< "' because it is impossible to match or cannot lead to "
|
||||
"legal IR (by cost model)\n";
|
||||
});
|
||||
list.pop_back();
|
||||
}
|
||||
};
|
||||
for (auto &it : patterns)
|
||||
processPatternList(it.second);
|
||||
processPatternList(anyOpPatterns);
|
||||
}
|
||||
|
||||
void PatternApplicator::walkAllPatterns(
|
||||
function_ref<void(const Pattern &)> walk) {
|
||||
for (auto &it : owningPatternList)
|
||||
walk(*it);
|
||||
}
|
||||
|
||||
LogicalResult PatternApplicator::matchAndRewrite(
|
||||
Operation *op, PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply,
|
||||
function_ref<void(const Pattern &)> onFailure,
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
||||
// Check to see if there are patterns matching this specific operation type.
|
||||
MutableArrayRef<RewritePattern *> opPatterns;
|
||||
auto patternIt = patterns.find(op->getName());
|
||||
if (patternIt != patterns.end())
|
||||
opPatterns = patternIt->second;
|
||||
|
||||
// Process the patterns for that match the specific operation type, and any
|
||||
// operation type in an interleaved fashion.
|
||||
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
|
||||
// and pass in a comparison function. That would make this code trivial.
|
||||
auto opIt = opPatterns.begin(), opE = opPatterns.end();
|
||||
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
|
||||
while (opIt != opE && anyIt != anyE) {
|
||||
// Try to match the pattern providing the most benefit.
|
||||
RewritePattern *pattern;
|
||||
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
|
||||
pattern = *(opIt++);
|
||||
else
|
||||
pattern = *(anyIt++);
|
||||
|
||||
// Otherwise, try to match the generic pattern.
|
||||
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
||||
onSuccess)))
|
||||
return success();
|
||||
}
|
||||
// If we break from the loop, then only one of the ranges can still have
|
||||
// elements. Loop over both without checking given that we don't need to
|
||||
// interleave anymore.
|
||||
for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
|
||||
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
|
||||
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
||||
onSuccess)))
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult PatternApplicator::matchAndRewrite(
|
||||
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
|
||||
function_ref<bool(const Pattern &)> canApply,
|
||||
function_ref<void(const Pattern &)> onFailure,
|
||||
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
||||
// Check that the pattern can be applied.
|
||||
if (canApply && !canApply(pattern))
|
||||
return failure();
|
||||
|
||||
// Try to match and rewrite this pattern. The patterns are sorted by
|
||||
// benefit, so if we match we can immediately rewrite.
|
||||
rewriter.setInsertionPoint(op);
|
||||
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
|
||||
return success(!onSuccess || succeeded(onSuccess(pattern)));
|
||||
|
||||
if (onFailure)
|
||||
onFailure(pattern);
|
||||
return failure();
|
||||
}
|
@ -7,7 +7,6 @@ add_mlir_library(MLIRTransforms
|
||||
Canonicalizer.cpp
|
||||
CopyRemoval.cpp
|
||||
CSE.cpp
|
||||
DialectConversion.cpp
|
||||
Inliner.cpp
|
||||
LocationSnapshot.cpp
|
||||
LoopCoalescing.cpp
|
||||
|
@ -12,8 +12,8 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "mlir/Analysis/CallGraph.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/ADT/SCCIterator.h"
|
||||
|
@ -1,4 +1,5 @@
|
||||
add_mlir_library(MLIRTransformUtils
|
||||
DialectConversion.cpp
|
||||
FoldUtils.cpp
|
||||
GreedyPatternRewriteDriver.cpp
|
||||
InliningUtils.cpp
|
||||
@ -19,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
|
||||
MLIRLoopAnalysis
|
||||
MLIRSCF
|
||||
MLIRPass
|
||||
MLIRRewrite
|
||||
MLIRStandard
|
||||
)
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
@ -74,8 +75,7 @@ computeConversionSet(iterator_range<Region::iterator> region,
|
||||
|
||||
/// A utility function to log a successful result for the given reason.
|
||||
template <typename... Args>
|
||||
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
|
||||
Args &&... args) {
|
||||
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
|
||||
LLVM_DEBUG({
|
||||
os.unindent();
|
||||
os.startLine() << "} -> SUCCESS";
|
||||
@ -88,8 +88,7 @@ static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
|
||||
|
||||
/// A utility function to log a failure result for the given reason.
|
||||
template <typename... Args>
|
||||
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
|
||||
Args &&... args) {
|
||||
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
|
||||
LLVM_DEBUG({
|
||||
os.unindent();
|
||||
os.startLine() << "} -> FAILURE : "
|
||||
@ -2033,21 +2032,21 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
|
||||
return minDepth;
|
||||
|
||||
// Sort the patterns by those likely to be the most beneficial.
|
||||
llvm::array_pod_sort(
|
||||
patternsByDepth.begin(), patternsByDepth.end(),
|
||||
[](const std::pair<const Pattern *, unsigned> *lhs,
|
||||
const std::pair<const Pattern *, unsigned> *rhs) {
|
||||
// First sort by the smaller pattern legalization depth.
|
||||
if (lhs->second != rhs->second)
|
||||
return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
|
||||
&rhs->second);
|
||||
llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
|
||||
[](const std::pair<const Pattern *, unsigned> *lhs,
|
||||
const std::pair<const Pattern *, unsigned> *rhs) {
|
||||
// First sort by the smaller pattern legalization
|
||||
// depth.
|
||||
if (lhs->second != rhs->second)
|
||||
return llvm::array_pod_sort_comparator<unsigned>(
|
||||
&lhs->second, &rhs->second);
|
||||
|
||||
// Then sort by the larger pattern benefit.
|
||||
auto lhsBenefit = lhs->first->getBenefit();
|
||||
auto rhsBenefit = rhs->first->getBenefit();
|
||||
return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
|
||||
&lhsBenefit);
|
||||
});
|
||||
// Then sort by the larger pattern benefit.
|
||||
auto lhsBenefit = lhs->first->getBenefit();
|
||||
auto rhsBenefit = rhs->first->getBenefit();
|
||||
return llvm::array_pod_sort_comparator<PatternBenefit>(
|
||||
&rhsBenefit, &lhsBenefit);
|
||||
});
|
||||
|
||||
// Update the legalization pattern to use the new sorted list.
|
||||
patterns.clear();
|
@ -10,8 +10,9 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -23,8 +23,8 @@
|
||||
#include "mlir/IR/BlockAndValueMapping.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -13,8 +13,8 @@
|
||||
|
||||
#include "mlir/Analysis/Utils.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
|
@ -10,10 +10,10 @@
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -7,9 +7,8 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "TestDialect.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@ -25,9 +24,9 @@ OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
|
||||
|
||||
OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
|
||||
ArrayRef<Attribute> operands) {
|
||||
auto argument_op = getOperand();
|
||||
auto argumentOp = getOperand();
|
||||
// The success case should cause the trait fold to be supressed.
|
||||
return argument_op.getDefiningOp() ? argument_op : OpFoldResult{};
|
||||
return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/LoopUtils.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
@ -93,7 +94,7 @@ void TestConvVectorization::runOnOperation() {
|
||||
// VectorTransforms.cpp
|
||||
vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
|
||||
context, vectorTransformsOptions);
|
||||
applyPatternsAndFoldGreedily(module, vectorTransferPatterns);
|
||||
applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
|
||||
|
||||
// Programmatic controlled lowering of linalg.copy and linalg.fill.
|
||||
PassManager pm(context);
|
||||
@ -105,13 +106,14 @@ void TestConvVectorization::runOnOperation() {
|
||||
OwningRewritePatternList vectorContractLoweringPatterns;
|
||||
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
|
||||
context, vectorTransformsOptions);
|
||||
applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
|
||||
applyPatternsAndFoldGreedily(module,
|
||||
std::move(vectorContractLoweringPatterns));
|
||||
|
||||
// Programmatic controlled lowering of vector.transfer only.
|
||||
OwningRewritePatternList vectorToLoopsPatterns;
|
||||
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
|
||||
VectorTransferToSCFOptions());
|
||||
applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
|
||||
applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
|
||||
|
||||
// Ensure we drop the marker in the end.
|
||||
module.walk([](linalg::LinalgOp op) {
|
||||
|
@ -11,8 +11,8 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -12,8 +12,8 @@
|
||||
|
||||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
@ -17,8 +17,8 @@
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
|
@ -14,9 +14,8 @@
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::vector;
|
||||
|
Loading…
x
Reference in New Issue
Block a user