[mlir][NFC] Move around the code related to PatternRewriting to improve layering
There are several pieces of pattern rewriting infra in IR/ that really shouldn't be there. This revision moves those pieces to a better location such that they are easier to evolve in the future(e.g. with PDL). More concretely this revision does the following: * Create a Transforms/GreedyPatternRewriteDriver.h and move the apply*andFold methods there. The definitions for these methods are already in Transforms/ so it doesn't make sense for the declarations to be in IR. * Create a new lib/Rewrite library and move PatternApplicator there. This new library will be focused on applying rewrites, and will also include compiling rewrites with PDL. Differential Revision: https://reviews.llvm.org/D89103
This commit is contained in:
parent
b99bd77162
commit
b6eb26fd0e
@ -416,12 +416,9 @@ private:
|
|||||||
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
|
void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Pattern-driven rewriters
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// OwningRewritePatternList
|
// OwningRewritePatternList
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
class OwningRewritePatternList {
|
class OwningRewritePatternList {
|
||||||
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
using PatternListT = std::vector<std::unique_ptr<RewritePattern>>;
|
||||||
@ -481,98 +478,6 @@ private:
|
|||||||
PatternListT patterns;
|
PatternListT patterns;
|
||||||
};
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// PatternApplicator
|
|
||||||
|
|
||||||
/// This class manages the application of a group of rewrite patterns, with a
|
|
||||||
/// user-provided cost model.
|
|
||||||
class PatternApplicator {
|
|
||||||
public:
|
|
||||||
/// The cost model dynamically assigns a PatternBenefit to a particular
|
|
||||||
/// pattern. Users can query contained patterns and pass analysis results to
|
|
||||||
/// applyCostModel. Patterns to be discarded should have a benefit of
|
|
||||||
/// `impossibleToMatch`.
|
|
||||||
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
|
|
||||||
|
|
||||||
explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
|
|
||||||
: owningPatternList(owningPatternList) {}
|
|
||||||
|
|
||||||
/// Attempt to match and rewrite the given op with any pattern, allowing a
|
|
||||||
/// predicate to decide if a pattern can be applied or not, and hooks for if
|
|
||||||
/// the pattern match was a success or failure.
|
|
||||||
///
|
|
||||||
/// canApply: called before each match and rewrite attempt; return false to
|
|
||||||
/// skip pattern.
|
|
||||||
/// onFailure: called when a pattern fails to match to perform cleanup.
|
|
||||||
/// onSuccess: called when a pattern match succeeds; return failure() to
|
|
||||||
/// invalidate the match and try another pattern.
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
|
|
||||||
function_ref<bool(const Pattern &)> canApply = {},
|
|
||||||
function_ref<void(const Pattern &)> onFailure = {},
|
|
||||||
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
|
|
||||||
|
|
||||||
/// Apply a cost model to the patterns within this applicator.
|
|
||||||
void applyCostModel(CostModel model);
|
|
||||||
|
|
||||||
/// Apply the default cost model that solely uses the pattern's static
|
|
||||||
/// benefit.
|
|
||||||
void applyDefaultCostModel() {
|
|
||||||
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Walk all of the patterns within the applicator.
|
|
||||||
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// Attempt to match and rewrite the given op with the given pattern, allowing
|
|
||||||
/// a predicate to decide if a pattern can be applied or not, and hooks for if
|
|
||||||
/// the pattern match was a success or failure.
|
|
||||||
LogicalResult
|
|
||||||
matchAndRewrite(Operation *op, const RewritePattern &pattern,
|
|
||||||
PatternRewriter &rewriter,
|
|
||||||
function_ref<bool(const Pattern &)> canApply,
|
|
||||||
function_ref<void(const Pattern &)> onFailure,
|
|
||||||
function_ref<LogicalResult(const Pattern &)> onSuccess);
|
|
||||||
|
|
||||||
/// The list that owns the patterns used within this applicator.
|
|
||||||
const OwningRewritePatternList &owningPatternList;
|
|
||||||
|
|
||||||
/// The set of patterns to match for each operation, stable sorted by benefit.
|
|
||||||
DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
|
|
||||||
/// The set of patterns that may match against any operation type, stable
|
|
||||||
/// sorted by benefit.
|
|
||||||
SmallVector<RewritePattern *, 1> anyOpPatterns;
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// applyPatternsGreedily
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
/// Rewrite the regions of the specified operation, which must be isolated from
|
|
||||||
/// above, by repeatedly applying the highest benefit patterns in a greedy
|
|
||||||
/// work-list driven manner. Return success if no more patterns can be matched
|
|
||||||
/// in the result operation regions.
|
|
||||||
/// Note: This does not apply patterns to the top-level operation itself. Note:
|
|
||||||
/// These methods also perform folding and simple dead-code elimination
|
|
||||||
/// before attempting to match any of the provided patterns.
|
|
||||||
///
|
|
||||||
LogicalResult
|
|
||||||
applyPatternsAndFoldGreedily(Operation *op,
|
|
||||||
const OwningRewritePatternList &patterns);
|
|
||||||
/// Rewrite the given regions, which must be isolated from above.
|
|
||||||
LogicalResult
|
|
||||||
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
|
||||||
const OwningRewritePatternList &patterns);
|
|
||||||
|
|
||||||
/// Applies the specified patterns on `op` alone while also trying to fold it,
|
|
||||||
/// by selecting the highest benefits patterns in a greedy manner. Returns
|
|
||||||
/// success if no more patterns can be matched. `erased` is set to true if `op`
|
|
||||||
/// was folded away or erased as a result of becoming dead. Note: This does not
|
|
||||||
/// apply any patterns recursively to the regions of `op`.
|
|
||||||
LogicalResult applyOpPatternsAndFold(Operation *op,
|
|
||||||
const OwningRewritePatternList &patterns,
|
|
||||||
bool *erased = nullptr);
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // MLIR_PATTERN_MATCH_H
|
#endif // MLIR_PATTERN_MATCH_H
|
||||||
|
85
mlir/include/mlir/Rewrite/PatternApplicator.h
Normal file
85
mlir/include/mlir/Rewrite/PatternApplicator.h
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
//===- PatternApplicator.h - PatternApplicator -------==---------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements an applicator that applies pattern rewrites based upon a
|
||||||
|
// user defined cost model.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_REWRITE_PATTERNAPPLICATOR_H
|
||||||
|
#define MLIR_REWRITE_PATTERNAPPLICATOR_H
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class PatternRewriter;
|
||||||
|
|
||||||
|
/// This class manages the application of a group of rewrite patterns, with a
|
||||||
|
/// user-provided cost model.
|
||||||
|
class PatternApplicator {
|
||||||
|
public:
|
||||||
|
/// The cost model dynamically assigns a PatternBenefit to a particular
|
||||||
|
/// pattern. Users can query contained patterns and pass analysis results to
|
||||||
|
/// applyCostModel. Patterns to be discarded should have a benefit of
|
||||||
|
/// `impossibleToMatch`.
|
||||||
|
using CostModel = function_ref<PatternBenefit(const Pattern &)>;
|
||||||
|
|
||||||
|
explicit PatternApplicator(const OwningRewritePatternList &owningPatternList)
|
||||||
|
: owningPatternList(owningPatternList) {}
|
||||||
|
|
||||||
|
/// Attempt to match and rewrite the given op with any pattern, allowing a
|
||||||
|
/// predicate to decide if a pattern can be applied or not, and hooks for if
|
||||||
|
/// the pattern match was a success or failure.
|
||||||
|
///
|
||||||
|
/// canApply: called before each match and rewrite attempt; return false to
|
||||||
|
/// skip pattern.
|
||||||
|
/// onFailure: called when a pattern fails to match to perform cleanup.
|
||||||
|
/// onSuccess: called when a pattern match succeeds; return failure() to
|
||||||
|
/// invalidate the match and try another pattern.
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(Operation *op, PatternRewriter &rewriter,
|
||||||
|
function_ref<bool(const Pattern &)> canApply = {},
|
||||||
|
function_ref<void(const Pattern &)> onFailure = {},
|
||||||
|
function_ref<LogicalResult(const Pattern &)> onSuccess = {});
|
||||||
|
|
||||||
|
/// Apply a cost model to the patterns within this applicator.
|
||||||
|
void applyCostModel(CostModel model);
|
||||||
|
|
||||||
|
/// Apply the default cost model that solely uses the pattern's static
|
||||||
|
/// benefit.
|
||||||
|
void applyDefaultCostModel() {
|
||||||
|
applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Walk all of the patterns within the applicator.
|
||||||
|
void walkAllPatterns(function_ref<void(const Pattern &)> walk);
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Attempt to match and rewrite the given op with the given pattern, allowing
|
||||||
|
/// a predicate to decide if a pattern can be applied or not, and hooks for if
|
||||||
|
/// the pattern match was a success or failure.
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(Operation *op, const RewritePattern &pattern,
|
||||||
|
PatternRewriter &rewriter,
|
||||||
|
function_ref<bool(const Pattern &)> canApply,
|
||||||
|
function_ref<void(const Pattern &)> onFailure,
|
||||||
|
function_ref<LogicalResult(const Pattern &)> onSuccess);
|
||||||
|
|
||||||
|
/// The list that owns the patterns used within this applicator.
|
||||||
|
const OwningRewritePatternList &owningPatternList;
|
||||||
|
|
||||||
|
/// The set of patterns to match for each operation, stable sorted by benefit.
|
||||||
|
DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
|
||||||
|
/// The set of patterns that may match against any operation type, stable
|
||||||
|
/// sorted by benefit.
|
||||||
|
SmallVector<RewritePattern *, 1> anyOpPatterns;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_REWRITE_PATTERNAPPLICATOR_H
|
52
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Normal file
52
mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file declares methods for applying a set of patterns greedily, choosing
|
||||||
|
// the patterns with the highest local benefit, until a fixed point is reached.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
|
||||||
|
#define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// applyPatternsGreedily
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Rewrite the regions of the specified operation, which must be isolated from
|
||||||
|
/// above, by repeatedly applying the highest benefit patterns in a greedy
|
||||||
|
/// work-list driven manner. Return success if no more patterns can be matched
|
||||||
|
/// in the result operation regions.
|
||||||
|
/// Note: This does not apply patterns to the top-level operation itself. Note:
|
||||||
|
/// These methods also perform folding and simple dead-code elimination
|
||||||
|
/// before attempting to match any of the provided patterns.
|
||||||
|
///
|
||||||
|
LogicalResult
|
||||||
|
applyPatternsAndFoldGreedily(Operation *op,
|
||||||
|
const OwningRewritePatternList &patterns);
|
||||||
|
/// Rewrite the given regions, which must be isolated from above.
|
||||||
|
LogicalResult
|
||||||
|
applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
|
||||||
|
const OwningRewritePatternList &patterns);
|
||||||
|
|
||||||
|
/// Applies the specified patterns on `op` alone while also trying to fold it,
|
||||||
|
/// by selecting the highest benefits patterns in a greedy manner. Returns
|
||||||
|
/// success if no more patterns can be matched. `erased` is set to true if `op`
|
||||||
|
/// was folded away or erased as a result of becoming dead. Note: This does not
|
||||||
|
/// apply any patterns recursively to the regions of `op`.
|
||||||
|
LogicalResult applyOpPatternsAndFold(Operation *op,
|
||||||
|
const OwningRewritePatternList &patterns,
|
||||||
|
bool *erased = nullptr);
|
||||||
|
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
|
@ -12,6 +12,7 @@ add_subdirectory(Interfaces)
|
|||||||
add_subdirectory(Parser)
|
add_subdirectory(Parser)
|
||||||
add_subdirectory(Pass)
|
add_subdirectory(Pass)
|
||||||
add_subdirectory(Reducer)
|
add_subdirectory(Reducer)
|
||||||
|
add_subdirectory(Rewrite)
|
||||||
add_subdirectory(Support)
|
add_subdirectory(Support)
|
||||||
add_subdirectory(TableGen)
|
add_subdirectory(TableGen)
|
||||||
add_subdirectory(Target)
|
add_subdirectory(Target)
|
||||||
|
@ -19,6 +19,7 @@
|
|||||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
#include "../GPUCommon/GPUOpsLowering.h"
|
#include "../GPUCommon/GPUOpsLowering.h"
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
#include "../GPUCommon/GPUOpsLowering.h"
|
#include "../GPUCommon/GPUOpsLowering.h"
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassRegistry.h"
|
#include "mlir/Pass/PassRegistry.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@ -123,7 +124,7 @@ class ConvertShapeConstraints
|
|||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
populateConvertShapeConstraintsConversionPatterns(patterns, context);
|
populateConvertShapeConstraintsConversionPatterns(patterns, context);
|
||||||
|
|
||||||
if (failed(applyPatternsAndFoldGreedily(func, patterns)))
|
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -17,8 +17,8 @@
|
|||||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -15,16 +15,12 @@
|
|||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/IR/Types.h"
|
|
||||||
#include "mlir/Target/LLVMIR/TypeTranslation.h"
|
#include "mlir/Target/LLVMIR/TypeTranslation.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "llvm/IR/DerivedTypes.h"
|
#include "llvm/IR/DerivedTypes.h"
|
||||||
#include "llvm/IR/Module.h"
|
#include "llvm/IR/Module.h"
|
||||||
|
@ -24,14 +24,10 @@
|
|||||||
#include "mlir/Dialect/Vector/VectorUtils.h"
|
#include "mlir/Dialect/Vector/VectorUtils.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Location.h"
|
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/OperationSupport.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/Types.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -24,7 +24,7 @@
|
|||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Affine/Passes.h"
|
#include "mlir/Dialect/Affine/Passes.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/LoopUtils.h"
|
#include "mlir/Transforms/LoopUtils.h"
|
||||||
#include "llvm/ADT/MapVector.h"
|
#include "llvm/ADT/MapVector.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
@ -15,7 +15,7 @@
|
|||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Affine/Passes.h"
|
#include "mlir/Dialect/Affine/Passes.h"
|
||||||
#include "mlir/IR/IntegerSet.h"
|
#include "mlir/IR/IntegerSet.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/Utils.h"
|
#include "mlir/Transforms/Utils.h"
|
||||||
|
|
||||||
#define DEBUG_TYPE "simplify-affine-structure"
|
#define DEBUG_TYPE "simplify-affine-structure"
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/IntegerSet.h"
|
#include "mlir/IR/IntegerSet.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
|
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/LoopUtils.h"
|
#include "mlir/Transforms/LoopUtils.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
@ -20,9 +20,8 @@
|
|||||||
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
@ -23,9 +23,9 @@
|
|||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Dominance.h"
|
#include "mlir/IR/Dominance.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::linalg;
|
using namespace mlir::linalg;
|
||||||
|
@ -22,6 +22,7 @@
|
|||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
|
@ -22,8 +22,8 @@
|
|||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineExprVisitor.h"
|
#include "mlir/IR/AffineExprVisitor.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/Support/CommandLine.h"
|
#include "llvm/Support/CommandLine.h"
|
||||||
|
|
||||||
|
@ -21,9 +21,9 @@
|
|||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
@ -12,10 +12,9 @@
|
|||||||
#include "mlir/Dialect/Quant/QuantizeUtils.h"
|
#include "mlir/Dialect/Quant/QuantizeUtils.h"
|
||||||
#include "mlir/Dialect/Quant/UniformSupport.h"
|
#include "mlir/Dialect/Quant/UniformSupport.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::quant;
|
using namespace mlir::quant;
|
||||||
|
@ -11,9 +11,8 @@
|
|||||||
#include "mlir/Dialect/Quant/Passes.h"
|
#include "mlir/Dialect/Quant/Passes.h"
|
||||||
#include "mlir/Dialect/Quant/QuantOps.h"
|
#include "mlir/Dialect/Quant/QuantOps.h"
|
||||||
#include "mlir/Dialect/Quant/UniformSupport.h"
|
#include "mlir/Dialect/Quant/UniformSupport.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::quant;
|
using namespace mlir::quant;
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
#include "mlir/Dialect/Shape/Transforms/Passes.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -8,14 +8,9 @@
|
|||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/Value.h"
|
|
||||||
#include "llvm/Support/Debug.h"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
#define DEBUG_TYPE "pattern-match"
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// PatternBenefit
|
// PatternBenefit
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
@ -205,135 +200,3 @@ void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) {
|
|||||||
cloneRegionBefore(region, *before->getParent(), before->getIterator());
|
cloneRegionBefore(region, *before->getParent(), before->getIterator());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// PatternApplicator
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
void PatternApplicator::applyCostModel(CostModel model) {
|
|
||||||
// Separate patterns by root kind to simplify lookup later on.
|
|
||||||
patterns.clear();
|
|
||||||
anyOpPatterns.clear();
|
|
||||||
for (const auto &pat : owningPatternList) {
|
|
||||||
// If the pattern is always impossible to match, just ignore it.
|
|
||||||
if (pat->getBenefit().isImpossibleToMatch()) {
|
|
||||||
LLVM_DEBUG({
|
|
||||||
llvm::dbgs()
|
|
||||||
<< "Ignoring pattern '" << pat->getRootKind()
|
|
||||||
<< "' because it is impossible to match (by pattern benefit)\n";
|
|
||||||
});
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (Optional<OperationName> opName = pat->getRootKind())
|
|
||||||
patterns[*opName].push_back(pat.get());
|
|
||||||
else
|
|
||||||
anyOpPatterns.push_back(pat.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort the patterns using the provided cost model.
|
|
||||||
llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
|
|
||||||
auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
|
|
||||||
return benefits[lhs] > benefits[rhs];
|
|
||||||
};
|
|
||||||
auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
|
|
||||||
// Special case for one pattern in the list, which is the most common case.
|
|
||||||
if (list.size() == 1) {
|
|
||||||
if (model(*list.front()).isImpossibleToMatch()) {
|
|
||||||
LLVM_DEBUG({
|
|
||||||
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
|
|
||||||
<< "' because it is impossible to match or cannot lead "
|
|
||||||
"to legal IR (by cost model)\n";
|
|
||||||
});
|
|
||||||
list.clear();
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect the dynamic benefits for the current pattern list.
|
|
||||||
benefits.clear();
|
|
||||||
for (RewritePattern *pat : list)
|
|
||||||
benefits.try_emplace(pat, model(*pat));
|
|
||||||
|
|
||||||
// Sort patterns with highest benefit first, and remove those that are
|
|
||||||
// impossible to match.
|
|
||||||
std::stable_sort(list.begin(), list.end(), cmp);
|
|
||||||
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
|
|
||||||
LLVM_DEBUG({
|
|
||||||
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
|
|
||||||
<< "' because it is impossible to match or cannot lead to "
|
|
||||||
"legal IR (by cost model)\n";
|
|
||||||
});
|
|
||||||
list.pop_back();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
for (auto &it : patterns)
|
|
||||||
processPatternList(it.second);
|
|
||||||
processPatternList(anyOpPatterns);
|
|
||||||
}
|
|
||||||
|
|
||||||
void PatternApplicator::walkAllPatterns(
|
|
||||||
function_ref<void(const Pattern &)> walk) {
|
|
||||||
for (auto &it : owningPatternList)
|
|
||||||
walk(*it);
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult PatternApplicator::matchAndRewrite(
|
|
||||||
Operation *op, PatternRewriter &rewriter,
|
|
||||||
function_ref<bool(const Pattern &)> canApply,
|
|
||||||
function_ref<void(const Pattern &)> onFailure,
|
|
||||||
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
|
||||||
// Check to see if there are patterns matching this specific operation type.
|
|
||||||
MutableArrayRef<RewritePattern *> opPatterns;
|
|
||||||
auto patternIt = patterns.find(op->getName());
|
|
||||||
if (patternIt != patterns.end())
|
|
||||||
opPatterns = patternIt->second;
|
|
||||||
|
|
||||||
// Process the patterns for that match the specific operation type, and any
|
|
||||||
// operation type in an interleaved fashion.
|
|
||||||
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
|
|
||||||
// and pass in a comparison function. That would make this code trivial.
|
|
||||||
auto opIt = opPatterns.begin(), opE = opPatterns.end();
|
|
||||||
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
|
|
||||||
while (opIt != opE && anyIt != anyE) {
|
|
||||||
// Try to match the pattern providing the most benefit.
|
|
||||||
RewritePattern *pattern;
|
|
||||||
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
|
|
||||||
pattern = *(opIt++);
|
|
||||||
else
|
|
||||||
pattern = *(anyIt++);
|
|
||||||
|
|
||||||
// Otherwise, try to match the generic pattern.
|
|
||||||
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
|
||||||
onSuccess)))
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
// If we break from the loop, then only one of the ranges can still have
|
|
||||||
// elements. Loop over both without checking given that we don't need to
|
|
||||||
// interleave anymore.
|
|
||||||
for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
|
|
||||||
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
|
|
||||||
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
|
||||||
onSuccess)))
|
|
||||||
return success();
|
|
||||||
}
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult PatternApplicator::matchAndRewrite(
|
|
||||||
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
|
|
||||||
function_ref<bool(const Pattern &)> canApply,
|
|
||||||
function_ref<void(const Pattern &)> onFailure,
|
|
||||||
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
|
||||||
// Check that the pattern can be applied.
|
|
||||||
if (canApply && !canApply(pattern))
|
|
||||||
return failure();
|
|
||||||
|
|
||||||
// Try to match and rewrite this pattern. The patterns are sorted by
|
|
||||||
// benefit, so if we match we can immediately rewrite.
|
|
||||||
rewriter.setInsertionPoint(op);
|
|
||||||
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
|
|
||||||
return success(!onSuccess || succeeded(onSuccess(pattern)));
|
|
||||||
|
|
||||||
if (onFailure)
|
|
||||||
onFailure(pattern);
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
|
12
mlir/lib/Rewrite/CMakeLists.txt
Normal file
12
mlir/lib/Rewrite/CMakeLists.txt
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
add_mlir_library(MLIRRewrite
|
||||||
|
PatternApplicator.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${MLIR_MAIN_INCLUDE_DIR}/mlir/Rewrite
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
mlir-generic-headers
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
)
|
148
mlir/lib/Rewrite/PatternApplicator.cpp
Normal file
148
mlir/lib/Rewrite/PatternApplicator.cpp
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
//===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements an applicator that applies pattern rewrites based upon a
|
||||||
|
// user defined cost model.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Rewrite/PatternApplicator.h"
|
||||||
|
#include "llvm/Support/Debug.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "pattern-match"
|
||||||
|
|
||||||
|
void PatternApplicator::applyCostModel(CostModel model) {
|
||||||
|
// Separate patterns by root kind to simplify lookup later on.
|
||||||
|
patterns.clear();
|
||||||
|
anyOpPatterns.clear();
|
||||||
|
for (const auto &pat : owningPatternList) {
|
||||||
|
// If the pattern is always impossible to match, just ignore it.
|
||||||
|
if (pat->getBenefit().isImpossibleToMatch()) {
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs()
|
||||||
|
<< "Ignoring pattern '" << pat->getRootKind()
|
||||||
|
<< "' because it is impossible to match (by pattern benefit)\n";
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (Optional<OperationName> opName = pat->getRootKind())
|
||||||
|
patterns[*opName].push_back(pat.get());
|
||||||
|
else
|
||||||
|
anyOpPatterns.push_back(pat.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the patterns using the provided cost model.
|
||||||
|
llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
|
||||||
|
auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
|
||||||
|
return benefits[lhs] > benefits[rhs];
|
||||||
|
};
|
||||||
|
auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
|
||||||
|
// Special case for one pattern in the list, which is the most common case.
|
||||||
|
if (list.size() == 1) {
|
||||||
|
if (model(*list.front()).isImpossibleToMatch()) {
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
|
||||||
|
<< "' because it is impossible to match or cannot lead "
|
||||||
|
"to legal IR (by cost model)\n";
|
||||||
|
});
|
||||||
|
list.clear();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect the dynamic benefits for the current pattern list.
|
||||||
|
benefits.clear();
|
||||||
|
for (RewritePattern *pat : list)
|
||||||
|
benefits.try_emplace(pat, model(*pat));
|
||||||
|
|
||||||
|
// Sort patterns with highest benefit first, and remove those that are
|
||||||
|
// impossible to match.
|
||||||
|
std::stable_sort(list.begin(), list.end(), cmp);
|
||||||
|
while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
|
||||||
|
LLVM_DEBUG({
|
||||||
|
llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
|
||||||
|
<< "' because it is impossible to match or cannot lead to "
|
||||||
|
"legal IR (by cost model)\n";
|
||||||
|
});
|
||||||
|
list.pop_back();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (auto &it : patterns)
|
||||||
|
processPatternList(it.second);
|
||||||
|
processPatternList(anyOpPatterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PatternApplicator::walkAllPatterns(
|
||||||
|
function_ref<void(const Pattern &)> walk) {
|
||||||
|
for (auto &it : owningPatternList)
|
||||||
|
walk(*it);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult PatternApplicator::matchAndRewrite(
|
||||||
|
Operation *op, PatternRewriter &rewriter,
|
||||||
|
function_ref<bool(const Pattern &)> canApply,
|
||||||
|
function_ref<void(const Pattern &)> onFailure,
|
||||||
|
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
||||||
|
// Check to see if there are patterns matching this specific operation type.
|
||||||
|
MutableArrayRef<RewritePattern *> opPatterns;
|
||||||
|
auto patternIt = patterns.find(op->getName());
|
||||||
|
if (patternIt != patterns.end())
|
||||||
|
opPatterns = patternIt->second;
|
||||||
|
|
||||||
|
// Process the patterns for that match the specific operation type, and any
|
||||||
|
// operation type in an interleaved fashion.
|
||||||
|
// FIXME: It'd be nice to just write an llvm::make_merge_range utility
|
||||||
|
// and pass in a comparison function. That would make this code trivial.
|
||||||
|
auto opIt = opPatterns.begin(), opE = opPatterns.end();
|
||||||
|
auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
|
||||||
|
while (opIt != opE && anyIt != anyE) {
|
||||||
|
// Try to match the pattern providing the most benefit.
|
||||||
|
RewritePattern *pattern;
|
||||||
|
if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
|
||||||
|
pattern = *(opIt++);
|
||||||
|
else
|
||||||
|
pattern = *(anyIt++);
|
||||||
|
|
||||||
|
// Otherwise, try to match the generic pattern.
|
||||||
|
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
||||||
|
onSuccess)))
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
// If we break from the loop, then only one of the ranges can still have
|
||||||
|
// elements. Loop over both without checking given that we don't need to
|
||||||
|
// interleave anymore.
|
||||||
|
for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
|
||||||
|
llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
|
||||||
|
if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
|
||||||
|
onSuccess)))
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult PatternApplicator::matchAndRewrite(
|
||||||
|
Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
|
||||||
|
function_ref<bool(const Pattern &)> canApply,
|
||||||
|
function_ref<void(const Pattern &)> onFailure,
|
||||||
|
function_ref<LogicalResult(const Pattern &)> onSuccess) {
|
||||||
|
// Check that the pattern can be applied.
|
||||||
|
if (canApply && !canApply(pattern))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Try to match and rewrite this pattern. The patterns are sorted by
|
||||||
|
// benefit, so if we match we can immediately rewrite.
|
||||||
|
rewriter.setInsertionPoint(op);
|
||||||
|
if (succeeded(pattern.matchAndRewrite(op, rewriter)))
|
||||||
|
return success(!onSuccess || succeeded(onSuccess(pattern)));
|
||||||
|
|
||||||
|
if (onFailure)
|
||||||
|
onFailure(pattern);
|
||||||
|
return failure();
|
||||||
|
}
|
@ -7,7 +7,6 @@ add_mlir_library(MLIRTransforms
|
|||||||
Canonicalizer.cpp
|
Canonicalizer.cpp
|
||||||
CopyRemoval.cpp
|
CopyRemoval.cpp
|
||||||
CSE.cpp
|
CSE.cpp
|
||||||
DialectConversion.cpp
|
|
||||||
Inliner.cpp
|
Inliner.cpp
|
||||||
LocationSnapshot.cpp
|
LocationSnapshot.cpp
|
||||||
LoopCoalescing.cpp
|
LoopCoalescing.cpp
|
||||||
|
@ -12,8 +12,8 @@
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "PassDetail.h"
|
#include "PassDetail.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
#include "mlir/Analysis/CallGraph.h"
|
#include "mlir/Analysis/CallGraph.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/InliningUtils.h"
|
#include "mlir/Transforms/InliningUtils.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "llvm/ADT/SCCIterator.h"
|
#include "llvm/ADT/SCCIterator.h"
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
add_mlir_library(MLIRTransformUtils
|
add_mlir_library(MLIRTransformUtils
|
||||||
|
DialectConversion.cpp
|
||||||
FoldUtils.cpp
|
FoldUtils.cpp
|
||||||
GreedyPatternRewriteDriver.cpp
|
GreedyPatternRewriteDriver.cpp
|
||||||
InliningUtils.cpp
|
InliningUtils.cpp
|
||||||
@ -19,5 +20,6 @@ add_mlir_library(MLIRTransformUtils
|
|||||||
MLIRLoopAnalysis
|
MLIRLoopAnalysis
|
||||||
MLIRSCF
|
MLIRSCF
|
||||||
MLIRPass
|
MLIRPass
|
||||||
|
MLIRRewrite
|
||||||
MLIRStandard
|
MLIRStandard
|
||||||
)
|
)
|
||||||
|
@ -12,6 +12,7 @@
|
|||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/Module.h"
|
#include "mlir/IR/Module.h"
|
||||||
|
#include "mlir/Rewrite/PatternApplicator.h"
|
||||||
#include "mlir/Transforms/Utils.h"
|
#include "mlir/Transforms/Utils.h"
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
@ -74,8 +75,7 @@ computeConversionSet(iterator_range<Region::iterator> region,
|
|||||||
|
|
||||||
/// A utility function to log a successful result for the given reason.
|
/// A utility function to log a successful result for the given reason.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
|
static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
|
||||||
Args &&... args) {
|
|
||||||
LLVM_DEBUG({
|
LLVM_DEBUG({
|
||||||
os.unindent();
|
os.unindent();
|
||||||
os.startLine() << "} -> SUCCESS";
|
os.startLine() << "} -> SUCCESS";
|
||||||
@ -88,8 +88,7 @@ static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
|
|||||||
|
|
||||||
/// A utility function to log a failure result for the given reason.
|
/// A utility function to log a failure result for the given reason.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
|
static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
|
||||||
Args &&... args) {
|
|
||||||
LLVM_DEBUG({
|
LLVM_DEBUG({
|
||||||
os.unindent();
|
os.unindent();
|
||||||
os.startLine() << "} -> FAILURE : "
|
os.startLine() << "} -> FAILURE : "
|
||||||
@ -2033,21 +2032,21 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
|
|||||||
return minDepth;
|
return minDepth;
|
||||||
|
|
||||||
// Sort the patterns by those likely to be the most beneficial.
|
// Sort the patterns by those likely to be the most beneficial.
|
||||||
llvm::array_pod_sort(
|
llvm::array_pod_sort(patternsByDepth.begin(), patternsByDepth.end(),
|
||||||
patternsByDepth.begin(), patternsByDepth.end(),
|
[](const std::pair<const Pattern *, unsigned> *lhs,
|
||||||
[](const std::pair<const Pattern *, unsigned> *lhs,
|
const std::pair<const Pattern *, unsigned> *rhs) {
|
||||||
const std::pair<const Pattern *, unsigned> *rhs) {
|
// First sort by the smaller pattern legalization
|
||||||
// First sort by the smaller pattern legalization depth.
|
// depth.
|
||||||
if (lhs->second != rhs->second)
|
if (lhs->second != rhs->second)
|
||||||
return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
|
return llvm::array_pod_sort_comparator<unsigned>(
|
||||||
&rhs->second);
|
&lhs->second, &rhs->second);
|
||||||
|
|
||||||
// Then sort by the larger pattern benefit.
|
// Then sort by the larger pattern benefit.
|
||||||
auto lhsBenefit = lhs->first->getBenefit();
|
auto lhsBenefit = lhs->first->getBenefit();
|
||||||
auto rhsBenefit = rhs->first->getBenefit();
|
auto rhsBenefit = rhs->first->getBenefit();
|
||||||
return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
|
return llvm::array_pod_sort_comparator<PatternBenefit>(
|
||||||
&lhsBenefit);
|
&rhsBenefit, &lhsBenefit);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Update the legalization pattern to use the new sorted list.
|
// Update the legalization pattern to use the new sorted list.
|
||||||
patterns.clear();
|
patterns.clear();
|
@ -10,8 +10,9 @@
|
|||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "mlir/Rewrite/PatternApplicator.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
#include "mlir/Transforms/RegionUtils.h"
|
#include "mlir/Transforms/RegionUtils.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
@ -23,8 +23,8 @@
|
|||||||
#include "mlir/IR/BlockAndValueMapping.h"
|
#include "mlir/IR/BlockAndValueMapping.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/IntegerSet.h"
|
#include "mlir/IR/IntegerSet.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Support/MathExtras.h"
|
#include "mlir/Support/MathExtras.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/RegionUtils.h"
|
#include "mlir/Transforms/RegionUtils.h"
|
||||||
#include "mlir/Transforms/Utils.h"
|
#include "mlir/Transforms/Utils.h"
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
|
@ -13,8 +13,8 @@
|
|||||||
|
|
||||||
#include "mlir/Analysis/Utils.h"
|
#include "mlir/Analysis/Utils.h"
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/LoopUtils.h"
|
#include "mlir/Transforms/LoopUtils.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
|
@ -10,10 +10,10 @@
|
|||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
#include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -7,9 +7,8 @@
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "TestDialect.h"
|
#include "TestDialect.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@ -25,9 +24,9 @@ OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold(
|
|||||||
|
|
||||||
OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
|
OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold(
|
||||||
ArrayRef<Attribute> operands) {
|
ArrayRef<Attribute> operands) {
|
||||||
auto argument_op = getOperand();
|
auto argumentOp = getOperand();
|
||||||
// The success case should cause the trait fold to be supressed.
|
// The success case should cause the trait fold to be supressed.
|
||||||
return argument_op.getDefiningOp() ? argument_op : OpFoldResult{};
|
return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{};
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "mlir/Transforms/LoopUtils.h"
|
#include "mlir/Transforms/LoopUtils.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
|
|
||||||
@ -93,7 +94,7 @@ void TestConvVectorization::runOnOperation() {
|
|||||||
// VectorTransforms.cpp
|
// VectorTransforms.cpp
|
||||||
vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
|
vectorTransferPatterns.insert<VectorTransferFullPartialRewriter>(
|
||||||
context, vectorTransformsOptions);
|
context, vectorTransformsOptions);
|
||||||
applyPatternsAndFoldGreedily(module, vectorTransferPatterns);
|
applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
|
||||||
|
|
||||||
// Programmatic controlled lowering of linalg.copy and linalg.fill.
|
// Programmatic controlled lowering of linalg.copy and linalg.fill.
|
||||||
PassManager pm(context);
|
PassManager pm(context);
|
||||||
@ -105,13 +106,14 @@ void TestConvVectorization::runOnOperation() {
|
|||||||
OwningRewritePatternList vectorContractLoweringPatterns;
|
OwningRewritePatternList vectorContractLoweringPatterns;
|
||||||
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
|
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
|
||||||
context, vectorTransformsOptions);
|
context, vectorTransformsOptions);
|
||||||
applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
|
applyPatternsAndFoldGreedily(module,
|
||||||
|
std::move(vectorContractLoweringPatterns));
|
||||||
|
|
||||||
// Programmatic controlled lowering of vector.transfer only.
|
// Programmatic controlled lowering of vector.transfer only.
|
||||||
OwningRewritePatternList vectorToLoopsPatterns;
|
OwningRewritePatternList vectorToLoopsPatterns;
|
||||||
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
|
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
|
||||||
VectorTransferToSCFOptions());
|
VectorTransferToSCFOptions());
|
||||||
applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
|
applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
|
||||||
|
|
||||||
// Ensure we drop the marker in the end.
|
// Ensure we drop the marker in the end.
|
||||||
module.walk([](linalg::LinalgOp op) {
|
module.walk([](linalg::LinalgOp op) {
|
||||||
|
@ -11,8 +11,8 @@
|
|||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -12,8 +12,8 @@
|
|||||||
|
|
||||||
#include "mlir/Dialect/GPU/Passes.h"
|
#include "mlir/Dialect/GPU/Passes.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::linalg;
|
using namespace mlir::linalg;
|
||||||
|
@ -17,8 +17,8 @@
|
|||||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SetVector.h"
|
#include "llvm/ADT/SetVector.h"
|
||||||
|
|
||||||
|
@ -14,9 +14,8 @@
|
|||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user