[mlir] Reapply 141423 mlir-query combinators plus fix (#146156)
An uninitialized variable that caused a crash (https://lab.llvm.org/buildbot/#/builders/164/builds/11004) was identified using the memory analyzer, leading to the reversion of https://github.com/llvm/llvm-project/pull/141423. This pull request reapplies the previously reverted changes and includes the fix, which has been tested locally following the steps at https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild. Note: the fix is included as part of the second commit
This commit is contained in:
parent
771ee8e387
commit
3702d64801
@ -108,6 +108,9 @@ public:
|
||||
const llvm::ArrayRef<ParserValue> args,
|
||||
Diagnostics *error) const = 0;
|
||||
|
||||
// If the matcher is variadic, it can take any number of arguments.
|
||||
virtual bool isVariadic() const = 0;
|
||||
|
||||
// Returns the number of arguments accepted by the matcher.
|
||||
virtual unsigned getNumArgs() const = 0;
|
||||
|
||||
@ -140,6 +143,8 @@ public:
|
||||
return marshaller(matcherFunc, matcherName, nameRange, args, error);
|
||||
}
|
||||
|
||||
bool isVariadic() const override { return false; }
|
||||
|
||||
unsigned getNumArgs() const override { return argKinds.size(); }
|
||||
|
||||
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
|
||||
@ -153,6 +158,54 @@ private:
|
||||
const std::vector<ArgKind> argKinds;
|
||||
};
|
||||
|
||||
class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
|
||||
public:
|
||||
using VarOp = DynMatcher::VariadicOperator;
|
||||
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
|
||||
VarOp varOp, StringRef matcherName)
|
||||
: minCount(minCount), maxCount(maxCount), varOp(varOp),
|
||||
matcherName(matcherName) {}
|
||||
|
||||
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
|
||||
Diagnostics *error) const override {
|
||||
if (args.size() < minCount || maxCount < args.size()) {
|
||||
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
|
||||
{llvm::Twine("requires between "), llvm::Twine(minCount),
|
||||
llvm::Twine(" and "), llvm::Twine(maxCount),
|
||||
llvm::Twine(" args, got "), llvm::Twine(args.size())});
|
||||
return VariantMatcher();
|
||||
}
|
||||
|
||||
std::vector<VariantMatcher> innerArgs;
|
||||
for (int64_t i = 0, e = args.size(); i != e; ++i) {
|
||||
const ParserValue &arg = args[i];
|
||||
const VariantValue &value = arg.value;
|
||||
if (!value.isMatcher()) {
|
||||
addError(error, arg.range, ErrorType::RegistryWrongArgType,
|
||||
{llvm::Twine(i + 1), llvm::Twine("matcher: "),
|
||||
llvm::Twine(value.getTypeAsString())});
|
||||
return VariantMatcher();
|
||||
}
|
||||
innerArgs.push_back(value.getMatcher());
|
||||
}
|
||||
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
|
||||
}
|
||||
|
||||
bool isVariadic() const override { return true; }
|
||||
|
||||
unsigned getNumArgs() const override { return 0; }
|
||||
|
||||
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
|
||||
kinds.push_back(ArgKind(ArgKind::Matcher));
|
||||
}
|
||||
|
||||
private:
|
||||
const unsigned minCount;
|
||||
const unsigned maxCount;
|
||||
const VarOp varOp;
|
||||
const StringRef matcherName;
|
||||
};
|
||||
|
||||
// Helper function to check if argument count matches expected count
|
||||
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
|
||||
llvm::ArrayRef<ParserValue> args,
|
||||
@ -224,6 +277,14 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
|
||||
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
|
||||
}
|
||||
|
||||
// Variadic operator overload.
|
||||
template <unsigned MinCount, unsigned MaxCount>
|
||||
std::unique_ptr<MatcherDescriptor>
|
||||
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
|
||||
StringRef matcherName) {
|
||||
return std::make_unique<VariadicOperatorMatcherDescriptor>(
|
||||
MinCount, MaxCount, func.varOp, matcherName);
|
||||
}
|
||||
} // namespace mlir::query::matcher::internal
|
||||
|
||||
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
|
||||
|
@ -21,7 +21,9 @@
|
||||
|
||||
namespace mlir::query::matcher {
|
||||
|
||||
/// A class that provides utilities to find operations in the IR.
|
||||
/// Finds and collects matches from the IR. After construction
|
||||
/// `collectMatches` can be used to traverse the IR and apply
|
||||
/// matchers.
|
||||
class MatchFinder {
|
||||
|
||||
public:
|
||||
|
@ -8,11 +8,11 @@
|
||||
//
|
||||
// Implements the base layer of the matcher framework.
|
||||
//
|
||||
// Matchers are methods that return a Matcher which provides a method one of the
|
||||
// following methods: match(Operation *op), match(Operation *op,
|
||||
// SetVector<Operation *> &matchedOps)
|
||||
// Matchers are methods that return a Matcher which provides a
|
||||
// `match(...)` method whose parameters define the context of the match.
|
||||
// Support includes simple (unary) matchers as well as matcher combinators
|
||||
// (anyOf, allOf, etc.)
|
||||
//
|
||||
// The matcher functions are defined in include/mlir/IR/Matchers.h.
|
||||
// This file contains the wrapper classes needed to construct matchers for
|
||||
// mlir-query.
|
||||
//
|
||||
@ -25,6 +25,15 @@
|
||||
#include "llvm/ADT/IntrusiveRefCntPtr.h"
|
||||
|
||||
namespace mlir::query::matcher {
|
||||
class DynMatcher;
|
||||
namespace internal {
|
||||
|
||||
bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
|
||||
ArrayRef<DynMatcher> innerMatchers);
|
||||
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
|
||||
ArrayRef<DynMatcher> innerMatchers);
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// Defaults to false if T has no match() method with the signature:
|
||||
// match(Operation* op).
|
||||
@ -84,6 +93,27 @@ private:
|
||||
MatcherFn matcherFn;
|
||||
};
|
||||
|
||||
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
|
||||
// match the given operation.
|
||||
using VariadicOperatorFunction = bool (*)(Operation *op,
|
||||
SetVector<Operation *> *matchedOps,
|
||||
ArrayRef<DynMatcher> innerMatchers);
|
||||
|
||||
template <VariadicOperatorFunction Func>
|
||||
class VariadicMatcher : public MatcherInterface {
|
||||
public:
|
||||
VariadicMatcher(std::vector<DynMatcher> matchers)
|
||||
: matchers(std::move(matchers)) {}
|
||||
|
||||
bool match(Operation *op) override { return Func(op, nullptr, matchers); }
|
||||
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
|
||||
return Func(op, &matchedOps, matchers);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<DynMatcher> matchers;
|
||||
};
|
||||
|
||||
// Matcher wraps a MatcherInterface implementation and provides match()
|
||||
// methods that redirect calls to the underlying implementation.
|
||||
class DynMatcher {
|
||||
@ -92,6 +122,31 @@ public:
|
||||
DynMatcher(MatcherInterface *implementation)
|
||||
: implementation(implementation) {}
|
||||
|
||||
// Construct from a variadic function.
|
||||
enum VariadicOperator {
|
||||
// Matches operations for which all provided matchers match.
|
||||
AllOf,
|
||||
// Matches operations for which at least one of the provided matchers
|
||||
// matches.
|
||||
AnyOf
|
||||
};
|
||||
|
||||
static std::unique_ptr<DynMatcher>
|
||||
constructVariadic(VariadicOperator Op,
|
||||
std::vector<DynMatcher> innerMatchers) {
|
||||
switch (Op) {
|
||||
case AllOf:
|
||||
return std::make_unique<DynMatcher>(
|
||||
new VariadicMatcher<internal::allOfVariadicOperator>(
|
||||
std::move(innerMatchers)));
|
||||
case AnyOf:
|
||||
return std::make_unique<DynMatcher>(
|
||||
new VariadicMatcher<internal::anyOfVariadicOperator>(
|
||||
std::move(innerMatchers)));
|
||||
}
|
||||
llvm_unreachable("Invalid Op value.");
|
||||
}
|
||||
|
||||
template <typename MatcherFn>
|
||||
static std::unique_ptr<DynMatcher>
|
||||
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
|
||||
@ -113,6 +168,59 @@ private:
|
||||
std::string functionName;
|
||||
};
|
||||
|
||||
// VariadicOperatorMatcher related types.
|
||||
template <typename... Ps>
|
||||
class VariadicOperatorMatcher {
|
||||
public:
|
||||
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
|
||||
: varOp(varOp), params(std::forward<Ps>(params)...) {}
|
||||
|
||||
operator std::unique_ptr<DynMatcher>() const & {
|
||||
return DynMatcher::constructVariadic(
|
||||
varOp, getMatchers(std::index_sequence_for<Ps...>()));
|
||||
}
|
||||
|
||||
operator std::unique_ptr<DynMatcher>() && {
|
||||
return DynMatcher::constructVariadic(
|
||||
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
|
||||
}
|
||||
|
||||
private:
|
||||
// Helper method to unpack the tuple into a vector.
|
||||
template <std::size_t... Is>
|
||||
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
|
||||
return {DynMatcher(std::get<Is>(params))...};
|
||||
}
|
||||
|
||||
template <std::size_t... Is>
|
||||
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
|
||||
return {DynMatcher(std::get<Is>(std::move(params)))...};
|
||||
}
|
||||
|
||||
const DynMatcher::VariadicOperator varOp;
|
||||
std::tuple<Ps...> params;
|
||||
};
|
||||
|
||||
// Overloaded function object to generate VariadicOperatorMatcher objects from
|
||||
// arbitrary matchers.
|
||||
template <unsigned MinCount, unsigned MaxCount>
|
||||
struct VariadicOperatorMatcherFunc {
|
||||
DynMatcher::VariadicOperator varOp;
|
||||
|
||||
template <typename... Ms>
|
||||
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
|
||||
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
|
||||
"invalid number of parameters for variadic matcher");
|
||||
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
|
||||
}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
|
||||
anyOf = {DynMatcher::AnyOf};
|
||||
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
|
||||
allOf = {DynMatcher::AllOf};
|
||||
} // namespace internal
|
||||
} // namespace mlir::query::matcher
|
||||
|
||||
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
|
||||
|
@ -6,7 +6,8 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file provides matchers for MLIRQuery that peform slicing analysis
|
||||
// This file defines slicing-analysis matchers that extend and abstract the
|
||||
// core implementations from `SliceAnalysis.h`.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -16,9 +17,9 @@
|
||||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
|
||||
/// Additionally, it limits the slice computation to a certain depth level using
|
||||
/// a custom filter.
|
||||
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
|
||||
/// if `innerMatcher` matches. The traversal stops once the desired depth level
|
||||
/// is reached.
|
||||
///
|
||||
/// Example: starting from node 9, assuming the matcher
|
||||
/// computes the slice for the first two depth levels:
|
||||
@ -119,6 +120,77 @@ bool BackwardSliceMatcher<Matcher>::matches(
|
||||
: backwardSlice.size() >= 1;
|
||||
}
|
||||
|
||||
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
|
||||
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
|
||||
template <typename BaseMatcher, typename Filter>
|
||||
class PredicateBackwardSliceMatcher {
|
||||
public:
|
||||
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
|
||||
bool inclusive, bool omitBlockArguments,
|
||||
bool omitUsesFromAbove)
|
||||
: innerMatcher(std::move(innerMatcher)),
|
||||
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
|
||||
omitBlockArguments(omitBlockArguments),
|
||||
omitUsesFromAbove(omitUsesFromAbove) {}
|
||||
|
||||
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
|
||||
backwardSlice.clear();
|
||||
BackwardSliceOptions options;
|
||||
options.inclusive = inclusive;
|
||||
options.omitUsesFromAbove = omitUsesFromAbove;
|
||||
options.omitBlockArguments = omitBlockArguments;
|
||||
if (innerMatcher.match(rootOp)) {
|
||||
options.filter = [&](Operation *subOp) {
|
||||
return !filterMatcher.match(subOp);
|
||||
};
|
||||
LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
|
||||
assert(result.succeeded() && "expected backward slice to succeed");
|
||||
(void)result;
|
||||
return options.inclusive ? backwardSlice.size() > 1
|
||||
: backwardSlice.size() >= 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
BaseMatcher innerMatcher;
|
||||
Filter filterMatcher;
|
||||
bool inclusive;
|
||||
bool omitBlockArguments;
|
||||
bool omitUsesFromAbove;
|
||||
};
|
||||
|
||||
/// Computes the forward-slice of all users reachable from `rootOp`,
|
||||
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
|
||||
template <typename BaseMatcher, typename Filter>
|
||||
class PredicateForwardSliceMatcher {
|
||||
public:
|
||||
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
|
||||
bool inclusive)
|
||||
: innerMatcher(std::move(innerMatcher)),
|
||||
filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
|
||||
|
||||
bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
|
||||
forwardSlice.clear();
|
||||
ForwardSliceOptions options;
|
||||
options.inclusive = inclusive;
|
||||
if (innerMatcher.match(rootOp)) {
|
||||
options.filter = [&](Operation *subOp) {
|
||||
return !filterMatcher.match(subOp);
|
||||
};
|
||||
getForwardSlice(rootOp, &forwardSlice, options);
|
||||
return options.inclusive ? forwardSlice.size() > 1
|
||||
: forwardSlice.size() >= 1;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
BaseMatcher innerMatcher;
|
||||
Filter filterMatcher;
|
||||
bool inclusive;
|
||||
};
|
||||
|
||||
/// Matches transitive defs of a top-level operation up to N levels.
|
||||
template <typename Matcher>
|
||||
inline BackwardSliceMatcher<Matcher>
|
||||
@ -130,7 +202,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
|
||||
omitUsesFromAbove);
|
||||
}
|
||||
|
||||
/// Matches all transitive defs of a top-level operation up to N levels
|
||||
/// Matches all transitive defs of a top-level operation up to N levels.
|
||||
template <typename Matcher>
|
||||
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
|
||||
int64_t maxDepth) {
|
||||
@ -139,6 +211,28 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
|
||||
false, false);
|
||||
}
|
||||
|
||||
/// Matches all transitive defs of a top-level operation and stops where
|
||||
/// `filterMatcher` rejects.
|
||||
template <typename BaseMatcher, typename Filter>
|
||||
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
|
||||
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
|
||||
bool inclusive, bool omitBlockArguments,
|
||||
bool omitUsesFromAbove) {
|
||||
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
|
||||
std::move(innerMatcher), std::move(filterMatcher), inclusive,
|
||||
omitBlockArguments, omitUsesFromAbove);
|
||||
}
|
||||
|
||||
/// Matches all users of a top-level operation and stops where
|
||||
/// `filterMatcher` rejects.
|
||||
template <typename BaseMatcher, typename Filter>
|
||||
inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
|
||||
m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
|
||||
bool inclusive) {
|
||||
return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
|
||||
std::move(innerMatcher), std::move(filterMatcher), inclusive);
|
||||
}
|
||||
|
||||
} // namespace mlir::query::matcher
|
||||
|
||||
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H
|
||||
|
@ -26,7 +26,12 @@ enum class ArgKind { Boolean, Matcher, Signed, String };
|
||||
// A variant matcher object to abstract simple and complex matchers into a
|
||||
// single object type.
|
||||
class VariantMatcher {
|
||||
class MatcherOps;
|
||||
class MatcherOps {
|
||||
public:
|
||||
std::optional<DynMatcher>
|
||||
constructVariadicOperator(DynMatcher::VariadicOperator varOp,
|
||||
ArrayRef<VariantMatcher> innerMatchers) const;
|
||||
};
|
||||
|
||||
// Payload interface to be specialized by each matcher type. It follows a
|
||||
// similar interface as VariantMatcher itself.
|
||||
@ -43,6 +48,9 @@ public:
|
||||
|
||||
// Clones the provided matcher.
|
||||
static VariantMatcher SingleMatcher(DynMatcher matcher);
|
||||
static VariantMatcher
|
||||
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
|
||||
ArrayRef<VariantMatcher> args);
|
||||
|
||||
// Makes the matcher the "null" matcher.
|
||||
void reset();
|
||||
@ -61,6 +69,7 @@ private:
|
||||
: value(std::move(value)) {}
|
||||
|
||||
class SinglePayload;
|
||||
class VariadicOpPayload;
|
||||
|
||||
std::shared_ptr<const Payload> value;
|
||||
};
|
||||
|
@ -1,5 +1,6 @@
|
||||
add_mlir_library(MLIRQueryMatcher
|
||||
MatchFinder.cpp
|
||||
MatchersInternal.cpp
|
||||
Parser.cpp
|
||||
RegistryManager.cpp
|
||||
VariantValue.cpp
|
||||
|
33
mlir/lib/Query/Matcher/MatchersInternal.cpp
Normal file
33
mlir/lib/Query/Matcher/MatchersInternal.cpp
Normal file
@ -0,0 +1,33 @@
|
||||
//===--- MatchersInternal.cpp----------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Query/Matcher/MatchersInternal.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
|
||||
namespace mlir::query::matcher {
|
||||
|
||||
namespace internal {
|
||||
|
||||
bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
|
||||
ArrayRef<DynMatcher> innerMatchers) {
|
||||
return llvm::all_of(innerMatchers, [&](const DynMatcher &matcher) {
|
||||
if (matchedOps)
|
||||
return matcher.match(op, *matchedOps);
|
||||
return matcher.match(op);
|
||||
});
|
||||
}
|
||||
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
|
||||
ArrayRef<DynMatcher> innerMatchers) {
|
||||
return llvm::any_of(innerMatchers, [&](const DynMatcher &matcher) {
|
||||
if (matchedOps)
|
||||
return matcher.match(op, *matchedOps);
|
||||
return matcher.match(op);
|
||||
});
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace mlir::query::matcher
|
@ -64,7 +64,7 @@ std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
|
||||
unsigned argNumber = ctxEntry.second;
|
||||
std::vector<ArgKind> nextTypeSet;
|
||||
|
||||
if (argNumber < ctor->getNumArgs())
|
||||
if (ctor->isVariadic() || argNumber < ctor->getNumArgs())
|
||||
ctor->getArgKinds(argNumber, nextTypeSet);
|
||||
|
||||
typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
|
||||
@ -83,7 +83,7 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
|
||||
const internal::MatcherDescriptor &matcher = *m.getValue();
|
||||
llvm::StringRef name = m.getKey();
|
||||
|
||||
unsigned numArgs = matcher.getNumArgs();
|
||||
unsigned numArgs = matcher.isVariadic() ? 1 : matcher.getNumArgs();
|
||||
std::vector<std::vector<ArgKind>> argKinds(numArgs);
|
||||
|
||||
for (const ArgKind &kind : acceptedTypes) {
|
||||
@ -115,6 +115,9 @@ RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
|
||||
}
|
||||
}
|
||||
|
||||
if (matcher.isVariadic())
|
||||
os << ",...";
|
||||
|
||||
os << ")";
|
||||
typedText += "(";
|
||||
|
||||
|
@ -27,12 +27,64 @@ private:
|
||||
DynMatcher matcher;
|
||||
};
|
||||
|
||||
class VariantMatcher::VariadicOpPayload : public VariantMatcher::Payload {
|
||||
public:
|
||||
VariadicOpPayload(DynMatcher::VariadicOperator varOp,
|
||||
std::vector<VariantMatcher> args)
|
||||
: varOp(varOp), args(std::move(args)) {}
|
||||
|
||||
std::optional<DynMatcher> getDynMatcher() const override {
|
||||
std::vector<DynMatcher> dynMatchers;
|
||||
for (auto variantMatcher : args) {
|
||||
std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher();
|
||||
if (dynMatcher)
|
||||
dynMatchers.push_back(dynMatcher.value());
|
||||
}
|
||||
auto result = DynMatcher::constructVariadic(varOp, dynMatchers);
|
||||
return *result;
|
||||
}
|
||||
|
||||
std::string getTypeAsString() const override {
|
||||
std::string inner;
|
||||
llvm::interleave(
|
||||
args, [&](auto const &arg) { inner += arg.getTypeAsString(); },
|
||||
[&] { inner += " & "; });
|
||||
return inner;
|
||||
}
|
||||
|
||||
private:
|
||||
const DynMatcher::VariadicOperator varOp;
|
||||
const std::vector<VariantMatcher> args;
|
||||
};
|
||||
|
||||
VariantMatcher::VariantMatcher() = default;
|
||||
|
||||
VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
|
||||
return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher)));
|
||||
}
|
||||
|
||||
VariantMatcher
|
||||
VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp,
|
||||
ArrayRef<VariantMatcher> args) {
|
||||
return VariantMatcher(
|
||||
std::make_shared<VariadicOpPayload>(varOp, std::move(args)));
|
||||
}
|
||||
|
||||
std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator(
|
||||
DynMatcher::VariadicOperator varOp,
|
||||
ArrayRef<VariantMatcher> innerMatchers) const {
|
||||
std::vector<DynMatcher> dynMatchers;
|
||||
for (const auto &innerMatcher : innerMatchers) {
|
||||
if (!innerMatcher.value)
|
||||
return std::nullopt;
|
||||
std::optional<DynMatcher> inner = innerMatcher.value->getDynMatcher();
|
||||
if (!inner)
|
||||
return std::nullopt;
|
||||
dynMatchers.push_back(*inner);
|
||||
}
|
||||
return *DynMatcher::constructVariadic(varOp, dynMatchers);
|
||||
}
|
||||
|
||||
std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
|
||||
return value ? value->getDynMatcher() : std::nullopt;
|
||||
}
|
||||
@ -120,11 +172,11 @@ void VariantValue::setSigned(int64_t newValue) {
|
||||
// Boolean
|
||||
bool VariantValue::isBoolean() const { return type == ValueType::Boolean; }
|
||||
|
||||
bool VariantValue::getBoolean() const { return value.Signed; }
|
||||
bool VariantValue::getBoolean() const { return value.Boolean; }
|
||||
|
||||
void VariantValue::setBoolean(bool newValue) {
|
||||
type = ValueType::Boolean;
|
||||
value.Signed = newValue;
|
||||
value.Boolean = newValue;
|
||||
}
|
||||
|
||||
bool VariantValue::isString() const { return type == ValueType::String; }
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include "QueryParser.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Query/Matcher/MatchFinder.h"
|
||||
#include "mlir/Query/QuerySession.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
@ -68,6 +69,8 @@ static Operation *extractFunction(std::vector<Operation *> &ops,
|
||||
// Clone operations and build function body
|
||||
std::vector<Operation *> clonedOps;
|
||||
std::vector<Value> clonedVals;
|
||||
// TODO: Handle extraction of operations with compute payloads defined via
|
||||
// regions.
|
||||
for (Operation *slicedOp : slice) {
|
||||
Operation *clonedOp =
|
||||
clonedOps.emplace_back(builder.clone(*slicedOp, mapper));
|
||||
@ -129,6 +132,8 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const {
|
||||
finder.flattenMatchedOps(matches);
|
||||
Operation *function =
|
||||
extractFunction(flattenedMatches, rootOp->getContext(), functionName);
|
||||
if (failed(verify(function)))
|
||||
return mlir::failure();
|
||||
os << "\n" << *function << "\n\n";
|
||||
function->erase();
|
||||
return mlir::success();
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-query %s -c "m getAllDefinitions(hasOpName(\"arith.addf\"),2)" | FileCheck %s
|
||||
// RUN: mlir-query %s -c "m anyOf(getAllDefinitions(hasOpName(\"arith.addf\"),2),getAllDefinitions(hasOpName(\"tensor.extract\"),1))" | FileCheck %s
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) {
|
||||
@ -19,14 +19,23 @@ func.func @slice_use_from_above(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>)
|
||||
}
|
||||
|
||||
// CHECK: Match #1:
|
||||
|
||||
// CHECK: %[[LINALG:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]}
|
||||
// CHECK-SAME: ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>)
|
||||
|
||||
// CHECK: {{.*}}.mlir:7:10: note: "root" binds here
|
||||
// CHECK: %[[ADDF1:.*]] = arith.addf %in, %in : f32
|
||||
|
||||
// CHECK: Match #2:
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
|
||||
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
|
||||
|
||||
// CHECK: {{.*}}.mlir:14:18: note: "root" binds here
|
||||
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
|
||||
|
||||
// CHECK: Match #3:
|
||||
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[LINALG]] {{\[\[.*\]\]}} : tensor<5x5xf32> into tensor<25xf32>
|
||||
// CHECK: %[[C2:.*]] = arith.constant {{.*}} : index
|
||||
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[COLLAPSED]][%[[C2]]] : tensor<25xf32>
|
||||
|
||||
// CHECK: {{.*}}.mlir:15:10: note: "root" binds here
|
||||
// CHECK: %[[ADDF2:.*]] = arith.addf %[[EXTRACTED]], %[[EXTRACTED]] : f32
|
27
mlir/test/mlir-query/forward-slice-by-predicate.mlir
Normal file
27
mlir/test/mlir-query/forward-slice-by-predicate.mlir
Normal file
@ -0,0 +1,27 @@
|
||||
// RUN: mlir-query %s -c "m getUsersByPredicate(anyOf(hasOpName(\"memref.alloc\"),isConstantOp()),anyOf(hasOpName(\"affine.load\"), hasOpName(\"memref.dealloc\")),true)" | FileCheck %s
|
||||
|
||||
func.func @slice_depth1_loop_nest_with_offsets() {
|
||||
%0 = memref.alloc() : memref<100xf32>
|
||||
%cst = arith.constant 7.000000e+00 : f32
|
||||
affine.for %i0 = 0 to 16 {
|
||||
%a0 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i0)
|
||||
affine.store %cst, %0[%a0] : memref<100xf32>
|
||||
}
|
||||
affine.for %i1 = 4 to 8 {
|
||||
%a1 = affine.apply affine_map<(d0) -> (d0 - 1)>(%i1)
|
||||
%1 = affine.load %0[%a1] : memref<100xf32>
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: Match #1:
|
||||
// CHECK: {{.*}}.mlir:4:8: note: "root" binds here
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<100xf32>
|
||||
|
||||
// CHECK: affine.store %cst, %0[%a0] : memref<100xf32>
|
||||
|
||||
// CHECK: Match #2:
|
||||
// CHECK: {{.*}}.mlir:5:10: note: "root" binds here
|
||||
// CHECK: %[[CST:.*]] = arith.constant 7.000000e+00 : f32
|
||||
|
||||
// CHECK: affine.store %[[CST]], %0[%a0] : memref<100xf32>
|
11
mlir/test/mlir-query/logical-operator-test.mlir
Normal file
11
mlir/test/mlir-query/logical-operator-test.mlir
Normal file
@ -0,0 +1,11 @@
|
||||
// RUN: mlir-query %s -c "m allOf(hasOpName(\"memref.alloca\"), hasOpAttrName(\"alignment\"))" | FileCheck %s
|
||||
|
||||
func.func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
|
||||
%0 = memref.alloca(%arg0, %arg1) : memref<?x?xf32>
|
||||
memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>
|
||||
return %0 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK: Match #1:
|
||||
// CHECK: {{.*}}.mlir:5:3: note: "root" binds here
|
||||
// CHECK: memref.alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>
|
29
mlir/test/mlir-query/slice-function-extraction.mlir
Normal file
29
mlir/test/mlir-query/slice-function-extraction.mlir
Normal file
@ -0,0 +1,29 @@
|
||||
// RUN: mlir-query %s -c "m getDefinitionsByPredicate(hasOpName(\"memref.store\"),hasOpName(\"memref.alloc\"),true,false,false).extract(\"backward_slice\")" | FileCheck %s
|
||||
|
||||
// CHECK: func.func @backward_slice(%{{.*}}: memref<10xf32>) -> (f32, index, index, f32, index, index, f32) {
|
||||
// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[I0:.*]] = affine.apply affine_map<()[s0] -> (s0)>()[%[[C0]]]
|
||||
// CHECK-NEXT: memref.store %[[CST0]], %{{.*}}[%[[I0]]] : memref<10xf32>
|
||||
// CHECK-NEXT: %[[CST2:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[I1:.*]] = affine.apply affine_map<() -> (0)>()
|
||||
// CHECK-NEXT: memref.store %[[CST2]], %{{.*}}[%[[I1]]] : memref<10xf32>
|
||||
// CHECK-NEXT: %[[C1:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[LOAD:.*]] = memref.load %{{.*}}[%[[C1]]] : memref<10xf32>
|
||||
// CHECK-NEXT: memref.store %[[LOAD]], %{{.*}}[%[[C1]]] : memref<10xf32>
|
||||
// CHECK-NEXT: return %[[CST0]], %[[C0]], %[[I0]], %[[CST2]], %[[I1]], %[[C1]], %[[LOAD]] : f32, index, index, f32, index, index, f32
|
||||
|
||||
func.func @slicing_memref_store_trivial() {
|
||||
%0 = memref.alloc() : memref<10xf32>
|
||||
%c0 = arith.constant 0 : index
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
affine.for %i1 = 0 to 10 {
|
||||
%1 = affine.apply affine_map<()[s0] -> (s0)>()[%c0]
|
||||
memref.store %cst, %0[%1] : memref<10xf32>
|
||||
%2 = memref.load %0[%c0] : memref<10xf32>
|
||||
%3 = affine.apply affine_map<()[] -> (0)>()[]
|
||||
memref.store %cst, %0[%3] : memref<10xf32>
|
||||
memref.store %2, %0[%c0] : memref<10xf32>
|
||||
}
|
||||
return
|
||||
}
|
@ -40,12 +40,22 @@ int main(int argc, char **argv) {
|
||||
query::matcher::Registry matcherRegistry;
|
||||
|
||||
// Matchers registered in alphabetical order for consistency:
|
||||
matcherRegistry.registerMatcher("allOf", query::matcher::internal::allOf);
|
||||
matcherRegistry.registerMatcher("anyOf", query::matcher::internal::anyOf);
|
||||
matcherRegistry.registerMatcher(
|
||||
"getAllDefinitions",
|
||||
query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
|
||||
matcherRegistry.registerMatcher(
|
||||
"getDefinitions",
|
||||
query::matcher::m_GetDefinitions<query::matcher::DynMatcher>);
|
||||
matcherRegistry.registerMatcher(
|
||||
"getAllDefinitions",
|
||||
query::matcher::m_GetAllDefinitions<query::matcher::DynMatcher>);
|
||||
"getDefinitionsByPredicate",
|
||||
query::matcher::m_GetDefinitionsByPredicate<query::matcher::DynMatcher,
|
||||
query::matcher::DynMatcher>);
|
||||
matcherRegistry.registerMatcher(
|
||||
"getUsersByPredicate",
|
||||
query::matcher::m_GetUsersByPredicate<query::matcher::DynMatcher,
|
||||
query::matcher::DynMatcher>);
|
||||
matcherRegistry.registerMatcher("hasOpAttrName",
|
||||
static_cast<HasOpAttrName *>(m_Attr));
|
||||
matcherRegistry.registerMatcher("hasOpName", static_cast<HasOpName *>(m_Op));
|
||||
|
Loading…
x
Reference in New Issue
Block a user